# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
Zhuohan Li's avatar
Zhuohan Li committed
7
import time
8
from http import HTTPStatus
9
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar
Zhuohan Li's avatar
Zhuohan Li committed
10

11
import regex as re
12
import torch
13
from fastapi import HTTPException, UploadFile
14
from openai.types.chat.chat_completion_audio import (
15
16
17
    ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
18
19
20
21
22
from openai.types.responses import (
    ResponseCodeInterpreterCallCodeDeltaEvent,
    ResponseCodeInterpreterCallCodeDoneEvent,
    ResponseCodeInterpreterCallCompletedEvent,
    ResponseCodeInterpreterCallInProgressEvent,
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
    ResponseCodeInterpreterCallInterpretingEvent,
    ResponseContentPartAddedEvent,
    ResponseContentPartDoneEvent,
    ResponseFunctionToolCall,
    ResponseInputItemParam,
    ResponseOutputItem,
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponsePrompt,
    ResponseReasoningItem,
    ResponseReasoningTextDeltaEvent,
    ResponseReasoningTextDoneEvent,
    ResponseStatus,
    ResponseWebSearchCallCompletedEvent,
    ResponseWebSearchCallInProgressEvent,
    ResponseWebSearchCallSearchingEvent,
)
40
from openai.types.responses import (
41
42
43
    ResponseCompletedEvent as OpenAIResponseCompletedEvent,
)
from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent
44
from openai.types.responses import (
45
46
    ResponseInProgressEvent as OpenAIResponseInProgressEvent,
)
47
from openai.types.responses.response_reasoning_item import (
48
49
    Content as ResponseReasoningTextContent,
)
50
from openai_harmony import Message as OpenAIHarmonyMessage
51

52
53
54
55
56
57
from vllm.utils.serial_utils import (
    EmbedDType,
    EncodingFormat,
    Endianness,
)

58
59
60
61
# Backward compatibility for OpenAI client versions
try:  # For older openai versions (< 1.100.0)
    from openai.types.responses import ResponseTextConfig
except ImportError:  # For newer openai versions (>= 1.100.0)
62
    from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig
63

64

65
from openai.types.responses.response import IncompleteDetails, ToolChoice
66
67
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
68
69
70
71
72
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    TypeAdapter,
73
    ValidationError,
74
    ValidationInfo,
75
    field_serializer,
76
77
78
    field_validator,
    model_validator,
)
Zhuohan Li's avatar
Zhuohan Li committed
79

80
from vllm import envs
81
82
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam
83
from vllm.logger import init_logger
84
from vllm.logprobs import Logprob
85
from vllm.pooling_params import PoolingParams
86
87
88
89
90
91
from vllm.sampling_params import (
    BeamSearchParams,
    RequestOutputKind,
    SamplingParams,
    StructuredOutputsParams,
)
92
93
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
94

95
96
# Module-level logger, named after this module.
logger = init_logger(__name__)

97
# Integer limits of torch.long; its min/max bound the `seed` request field.
_LONG_INFO = torch.iinfo(torch.long)
98

Zhuohan Li's avatar
Zhuohan Li committed
99

100
class OpenAIBaseModel(BaseModel):
    # Base class for the protocol models in this module. Extra request fields
    # are accepted, and a warning is logged for keys matching no declared
    # field or alias.

    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Per-class cache of accepted keys (field names plus aliases), filled
    # lazily on the first validation of each concrete class.
    field_names: ClassVar[set[str] | None] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Run normal validation, then warn about unrecognised input keys.

        Args:
            data: Raw input handed to the model; only inspected when it is a
                ``dict``.
            handler: Pydantic's inner validation callable for this model.

        Returns:
            Whatever ``handler(data)`` returns; the check never alters it.
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Read the cache from this class's own namespace only. A plain
        # ``cls.field_names`` lookup walks the MRO, so a subclass of an
        # already-validated model would reuse the parent's set and report its
        # own additional fields as ignored.
        field_names = cls.__dict__.get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        extra_fields = data.keys() - field_names
        if extra_fields:
            logger.warning(
                "The following fields were present in the request but ignored: %s",
                extra_fields,
            )
        return result
130
131


132
class ErrorInfo(OpenAIBaseModel):
    # Error details carried under `ErrorResponse.error`.
    message: str
    type: str
    # Optional; defaults to None when not supplied.
    param: str | None = None
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
137
138


139
140
141
142
class ErrorResponse(OpenAIBaseModel):
    # Envelope that places the error details under a top-level `error` key.
    error: ErrorInfo


143
class ModelPermission(OpenAIBaseModel):
    # Permission entry listed in `ModelCard.permission`.
    # `id` and `created` are generated per instance when not supplied: a
    # "modelperm-" prefixed random uuid and the current Unix time in seconds.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: str | None = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
156
157


158
class ModelCard(OpenAIBaseModel):
    # One model entry, as listed in `ModelList.data`.
    id: str
    object: str = "model"
    # Defaults to the current Unix time (seconds) at instantiation.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: str | None = None
    parent: str | None = None
    max_model_len: int | None = None
    # default_factory gives every card its own fresh list.
    permission: list[ModelPermission] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
167
168


169
class ModelList(OpenAIBaseModel):
    # List envelope holding `ModelCard` entries.
    object: str = "list"
    data: list[ModelCard] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
172
173


174
class PromptTokenUsageInfo(OpenAIBaseModel):
    # Prompt-token breakdown attached to `UsageInfo.prompt_tokens_details`.
    cached_tokens: int | None = None
176
177


178
class UsageInfo(OpenAIBaseModel):
    # Token-count fields for a request; all counts default to 0.
    prompt_tokens: int = 0
    total_tokens: int = 0
    # Unlike the other counts, this one may also be explicitly None.
    completion_tokens: int | None = 0
    prompt_tokens_details: PromptTokenUsageInfo | None = None
Zhuohan Li's avatar
Zhuohan Li committed
183
184


185
186
class RequestResponseMetadata(BaseModel):
    # Derives from plain BaseModel, not OpenAIBaseModel, so it gets neither
    # the extra="allow" config nor the extra-field warning.
    request_id: str
    # None until a final usage record is assigned.
    final_usage_info: UsageInfo | None = None
188
189


190
191
class JsonSchemaResponseFormat(OpenAIBaseModel):
    # JSON-schema payload referenced by `ResponseFormat.json_schema`.
    name: str
    description: str | None = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    strict: bool | None = None
197
198


199
class LegacyStructuralTag(OpenAIBaseModel):
    # One entry of `LegacyStructuralTagResponseFormat.structures`: a
    # begin/end string pair with an optional schema.
    begin: str
    # schema is the field, but that causes conflicts with pydantic so
    # instead use structural_tag_schema with an alias
    structural_tag_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    end: str


207
class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
    # Structural-tag response format expressed as structures plus triggers.
    type: Literal["structural_tag"]
    structures: list[LegacyStructuralTag]
    triggers: list[str]


213
214
215
216
217
218
219
220
221
222
class StructuralTagResponseFormat(OpenAIBaseModel):
    # Structural-tag response format whose `format` payload is typed `Any`,
    # so this model places no constraint on its contents.
    type: Literal["structural_tag"]
    format: Any


# Either structural-tag variant: the legacy structures/triggers form or the
# `format`-based form.
AnyStructuralTagResponseFormat: TypeAlias = (
    LegacyStructuralTagResponseFormat | StructuralTagResponseFormat
)


223
class ResponseFormat(OpenAIBaseModel):
    # Non-structural-tag response format selector.
    # type must be "json_schema", "json_object", or "text"
    type: Literal["text", "json_object", "json_schema"]
    # Optional here: this model does not itself require it for any `type`.
    json_schema: JsonSchemaResponseFormat | None = None
227
228


229
230
231
# Every response-format shape accepted by `ChatCompletionRequest.response_format`.
AnyResponseFormat: TypeAlias = (
    ResponseFormat | StructuralTagResponseFormat | LegacyStructuralTagResponseFormat
)
232
233


234
class StreamOptions(OpenAIBaseModel):
    # Streaming options referenced by `ChatCompletionRequest.stream_options`.
    include_usage: bool | None = True
    continuous_usage_stats: bool | None = False
237
238


239
240
class FunctionDefinition(OpenAIBaseModel):
    # Function description carried by `ChatCompletionToolsParam.function`.
    name: str
    description: str | None = None
    # Free-form mapping; this model does not validate its structure.
    parameters: dict[str, Any] | None = None
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258


class ChatCompletionToolsParam(OpenAIBaseModel):
    # One entry of `ChatCompletionRequest.tools`; "function" is the only
    # accepted `type`.
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    # Identifies a function by name inside a named tool choice.
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    # Object form of `ChatCompletionRequest.tool_choice`, naming one function.
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


259
260
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
261
262
class LogitsProcessorConstructor(BaseModel):
    # Constructor spec for a logits processor: the qualified name to resolve,
    # plus optional positional and keyword arguments to call it with.
    qualname: str
    args: list[Any] | None = None
    kwargs: dict[str, Any] | None = None

    # Unknown keys are rejected so that `kwargs` can be an ordinary field.
    model_config = ConfigDict(extra="forbid")

268

269
# Request-supplied list mixing bare qualified names and constructor specs.
LogitsProcessors = list[str | LogitsProcessorConstructor]
270
271


272
def get_logits_processors(
    processors: LogitsProcessors | None, pattern: str | None
) -> list[Any] | None:
    """Resolve the requested logits processors, enforcing the allow pattern.

    Args:
        processors: Qualified names, or constructor specs (``qualname`` plus
            optional ``args``/``kwargs``), of the processors to load.
        pattern: Regular expression every qualified name must match. It is
            applied with ``re.match``, so it is anchored at the start of the
            name only. ``None`` or an empty string means the server accepts
            no logits processors.

    Returns:
        The resolved objects in request order (constructor specs are called
        with their arguments), or ``None`` when no processors were requested.

    Raises:
        ValueError: If processors are requested but no pattern is configured,
            a name does not match the pattern, or a name cannot be resolved.
    """
    if not processors:
        return None

    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information."
        )

    resolved = []
    for processor in processors:
        is_constructor = not isinstance(processor, str)
        qualname = processor.qualname if is_constructor else processor

        # Check the allow-list before importing anything named by the request.
        if not re.match(pattern, qualname):
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information."
            )

        try:
            logits_processor = resolve_obj_by_qualname(qualname)
        except Exception as e:
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {e}"
            ) from e

        # A constructor spec names a class/factory: call it to get the
        # processor. A bare name is used as resolved.
        if isinstance(processor, LogitsProcessorConstructor):
            logits_processor = logits_processor(
                *processor.args or [], **processor.kwargs or {}
            )
        resolved.append(logits_processor)
    return resolved


306
307
308
# Item types accepted in the list form of `ResponsesRequest.input`.
ResponseInputOutputItem: TypeAlias = (
    ResponseInputItemParam | ResponseReasoningItem | ResponseFunctionToolCall
)
309
310


311
312
313
class ResponsesRequest(OpenAIBaseModel):
    # Request body for the Responses API, plus vLLM-specific extras.
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/responses/create
    background: bool | None = False
    include: (
        list[
            Literal[
                "code_interpreter_call.outputs",
                "computer_call_output.output.image_url",
                "file_search_call.results",
                "message.input_image.image_url",
                "message.output_text.logprobs",
                "reasoning.encrypted_content",
            ],
        ]
        | None
    ) = None
    input: str | list[ResponseInputOutputItem]
    instructions: str | None = None
    max_output_tokens: int | None = None
    max_tool_calls: int | None = None
    metadata: Metadata | None = None
    model: str | None = None
    parallel_tool_calls: bool | None = True
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
    store: bool | None = True
    stream: bool | None = False
    temperature: float | None = None
    text: ResponseTextConfig | None = None
    tool_choice: ToolChoice = "auto"
    tools: list[Tool] = Field(default_factory=list)
    top_logprobs: int | None = 0
    top_p: float | None = None
    truncation: Literal["auto", "disabled"] | None = "disabled"
    user: str | None = None

    # --8<-- [start:responses-extra-params]
    request_id: str = Field(
        default_factory=lambda: f"resp_{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."
        ),
    )

    enable_response_messages: bool = Field(
        default=False,
        description=(
            "Dictates whether or not to return messages as part of the "
            "response object. Currently only supported for "
            "non-background and gpt-oss only. "
        ),
    )
    # similar to input_messages / output_messages in ResponsesResponse
    # we take in previous_input_messages (ie in harmony format)
    # this cannot be used in conjunction with previous_response_id
    # TODO: consider supporting non harmony messages as well
    previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
    # --8<-- [end:responses-extra-params]

    # Last-resort values used by to_sampling_params() when neither the request
    # nor the caller-provided defaults supply one.
    _DEFAULT_SAMPLING_PARAMS = {
        "temperature": 1.0,
        "top_p": 1.0,
    }

    def to_sampling_params(
        self,
        default_max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Build the ``SamplingParams`` for this request.

        Args:
            default_max_tokens: Upper bound on generated tokens. Used as-is
                when ``max_output_tokens`` is unset, otherwise the smaller of
                the two wins.
            default_sampling_params: Optional fallbacks for ``temperature``,
                ``top_p`` and ``stop_token_ids``. Request values take
                precedence for ``temperature`` and ``top_p``.

        Returns:
            Sampling parameters with DELTA output for streaming requests and
            FINAL_ONLY otherwise.

        Raises:
            NotImplementedError: If ``text.format`` asks for ``json_object``.
        """
        if self.max_output_tokens is None:
            max_tokens = default_max_tokens
        else:
            max_tokens = min(self.max_output_tokens, default_max_tokens)

        default_sampling_params = default_sampling_params or {}
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        stop_token_ids = default_sampling_params.get("stop_token_ids")

        # Structured output
        structured_outputs = None
        if self.text is not None and self.text.format is not None:
            response_format = self.text.format
            if (
                response_format.type == "json_schema"
                and response_format.schema_ is not None
            ):
                structured_outputs = StructuredOutputsParams(
                    json=response_format.schema_
                )
            elif response_format.type == "json_object":
                raise NotImplementedError("json_object is not supported")

        # TODO: add more parameters
        return SamplingParams.from_optional(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            # Logprobs are requested only when the caller opted in via `include`.
            logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
            stop_token_ids=stop_token_ids,
            output_kind=(
                RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
            ),
            structured_outputs=structured_outputs,
        )

    def is_include_output_logprobs(self) -> bool:
        """Check if the request includes output logprobs."""
        # A None `include` fails the isinstance check and so yields False.
        return (
            isinstance(self.include, list)
            and "message.output_text.logprobs" in self.include
        )

    @model_validator(mode="before")
    @classmethod
    def validate_background(cls, data):
        """Reject ``background`` requests that explicitly disable ``store``."""
        # Non-dict input is left for pydantic's own validation to handle.
        if not isinstance(data, dict):
            return data
        if not data.get("background"):
            return data
        if not data.get("store", True):
            raise ValueError("background can only be used when `store` is true")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt(cls, data):
        """Reject requests that supply a ``prompt`` template."""
        if not isinstance(data, dict):
            return data
        if data.get("prompt") is not None:
            raise ValueError("prompt template is not supported")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Require a V1 engine and a non-empty string when ``cache_salt`` is set."""
        if not isinstance(data, dict):
            return data
        cache_salt = data.get("cache_salt")
        if cache_salt is None:
            return data
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Parameter 'cache_salt' is not supported with "
                "this instance of vLLM, which uses engine V0."
            )
        if not isinstance(cache_salt, str) or not cache_salt:
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def function_call_parsing(cls, data):
        """Parse function_call dictionaries into ResponseFunctionToolCall objects.
        This ensures Pydantic can properly resolve union types in the input field.
        Function calls provided as dicts are converted to ResponseFunctionToolCall
        objects before validation, while invalid structures are left for Pydantic
        to reject with appropriate error messages.
        """
        # Non-dict input is left for pydantic's own validation to handle.
        if not isinstance(data, dict):
            return data

        input_data = data.get("input")

        # Early return for None, strings, or bytes
        # (strings are iterable but shouldn't be processed)
        if input_data is None or isinstance(input_data, (str, bytes)):
            return data

        # Convert iterators (like ValidatorIterator) to list
        if not isinstance(input_data, list):
            try:
                input_data = list(input_data)
            except TypeError:
                # Not iterable, leave as-is for Pydantic to handle
                return data

        processed_input = []
        for item in input_data:
            if isinstance(item, dict) and item.get("type") == "function_call":
                try:
                    processed_input.append(ResponseFunctionToolCall(**item))
                except ValidationError:
                    # Let Pydantic handle validation for malformed function calls
                    logger.debug(
                        "Failed to parse function_call to ResponseFunctionToolCall, "
                        "leaving for Pydantic validation"
                    )
                    processed_input.append(item)
            else:
                processed_input.append(item)

        # Always store the list back: a one-shot iterator has been consumed
        # by the conversion above.
        data["input"] = processed_input
        return data

530

531
class ChatCompletionRequest(OpenAIBaseModel):
532
533
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
534
    messages: list[ChatCompletionMessageParam]
535
536
537
538
539
540
    model: str | None = None
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: bool | None = False
    top_logprobs: int | None = 0
    max_tokens: int | None = Field(
541
        default=None,
542
543
        deprecated="max_tokens is deprecated in favor of "
        "the max_completion_tokens field",
544
    )
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
    max_completion_tokens: int | None = None
    n: int | None = 1
    presence_penalty: float | None = 0.0
    response_format: AnyResponseFormat | None = None
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    temperature: float | None = None
    top_p: float | None = None
    tools: list[ChatCompletionToolsParam] | None = None
    tool_choice: (
        Literal["none"]
        | Literal["auto"]
        | Literal["required"]
        | ChatCompletionNamedToolChoiceParam
        | None
    ) = "none"
    reasoning_effort: Literal["low", "medium", "high"] | None = None
564
    include_reasoning: bool = True
565

566
    # NOTE this will be ignored by vLLM -- the model determines the behavior
567
568
    parallel_tool_calls: bool | None = False
    user: str | None = None
569

570
    # --8<-- [start:chat-completion-sampling-params]
571
    best_of: int | None = None
572
    use_beam_search: bool = False
573
574
575
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
576
    length_penalty: float = 1.0
577
    stop_token_ids: list[int] | None = []
578
579
580
581
582
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
583
584
585
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    prompt_logprobs: int | None = None
    allowed_token_ids: list[int] | None = None
586
    bad_words: list[str] = Field(default_factory=list)
587
    # --8<-- [end:chat-completion-sampling-params]
588

589
    # --8<-- [start:chat-completion-extra-params]
590
    echo: bool = Field(
591
592
593
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
594
595
            "if they belong to the same role."
        ),
596
    )
597
    add_generation_prompt: bool = Field(
598
        default=True,
599
600
601
602
603
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
604
    )
605
606
    continue_final_message: bool = Field(
        default=False,
607
608
609
610
611
612
613
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
614
    )
615
    add_special_tokens: bool = Field(
616
617
618
619
620
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
621
            "special tokens so this should be set to false (as is the "
622
623
            "default)."
        ),
624
    )
625
    documents: list[dict[str, str]] | None = Field(
626
        default=None,
627
628
629
630
631
632
633
        description=(
            "A list of dicts representing documents that will be accessible to "
            "the model if it is performing RAG (retrieval-augmented generation)."
            " If the template does not support RAG, this argument will have no "
            "effect. We recommend that each document should be a dict containing "
            '"title" and "text" keys.'
        ),
634
    )
635
    chat_template: str | None = Field(
636
637
638
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
639
640
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
641
642
            "does not define one."
        ),
643
    )
644
    chat_template_kwargs: dict[str, Any] | None = Field(
645
        default=None,
646
647
        description=(
            "Additional keyword args to pass to the template renderer. "
648
649
            "Will be accessible by the chat template."
        ),
650
    )
651
    mm_processor_kwargs: dict[str, Any] | None = Field(
652
653
654
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
655
    structured_outputs: StructuredOutputsParams | None = Field(
656
        default=None,
657
        description="Additional kwargs for structured outputs",
658
    )
659
    guided_json: str | dict | BaseModel | None = Field(
660
661
662
663
        default=None,
        description=(
            "`guided_json` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
664
665
            "Please pass `json` to `structured_outputs` instead."
        ),
666
    )
667
    guided_regex: str | None = Field(
668
669
670
671
        default=None,
        description=(
            "`guided_regex` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
672
673
            "Please pass `regex` to `structured_outputs` instead."
        ),
674
    )
675
    guided_choice: list[str] | None = Field(
676
677
678
679
        default=None,
        description=(
            "`guided_choice` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
680
681
            "Please pass `choice` to `structured_outputs` instead."
        ),
682
    )
683
    guided_grammar: str | None = Field(
684
685
686
687
        default=None,
        description=(
            "`guided_grammar` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
688
689
            "Please pass `grammar` to `structured_outputs` instead."
        ),
690
    )
691
    structural_tag: str | None = Field(
692
693
694
695
        default=None,
        description=(
            "`structural_tag` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
696
697
            "Please pass `structural_tag` to `structured_outputs` instead."
        ),
698
    )
699
    guided_decoding_backend: str | None = Field(
700
701
702
703
        default=None,
        description=(
            "`guided_decoding_backend` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
704
705
            "Please remove it from your request."
        ),
706
    )
707
    guided_whitespace_pattern: str | None = Field(
708
709
710
711
712
713
714
        default=None,
        description=(
            "`guided_whitespace_pattern` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `whitespace_pattern` to `structured_outputs` instead."
        ),
    )
715
716
717
718
719
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
720
721
            "if the served model does not use priority scheduling."
        ),
722
    )
723
724
725
726
727
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
728
729
            "through out the inference process and return in response."
        ),
730
    )
731
    logits_processors: LogitsProcessors | None = Field(
732
733
734
735
736
737
738
739
740
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
741
742
743
            "{'param': 'value'}}."
        ),
    )
744
    # Represent logprob tokens as 'token_id:{token_id}' strings.
    return_tokens_as_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented  as strings "
            "of the form 'token_id:{token_id}' so that tokens that are not "
            "JSON-encodable can be identified."
        ),
    )
    # Include token IDs next to the generated text in the response.
    return_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. "
            "In streaming mode, prompt_token_ids is included only in the first "
            "chunk, and token_ids contains the delta tokens for each chunk. "
            "This is useful for debugging or when you need to map generated "
            "text back to input tokens."
        ),
    )
    # Salt for the prefix cache; validated by `check_cache_salt_support`.
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. "
            "The salt should be random, protected from access by 3rd parties, "
            "and long enough to be unpredictable (e.g., 43 characters "
            "base64-encoded, corresponding to 256 bit). "
            "Not supported by vLLM engine V0."
        ),
    )
    # Forwarded to the engine through `SamplingParams.extra_args`.
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )

    # Free-form extension arguments; become `SamplingParams.extra_args`.
    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or numeric values, "
            "used by custom extensions."
        ),
    )

786
    # --8<-- [end:chat-completion-extra-params]
Zhuohan Li's avatar
Zhuohan Li committed
787

788
789
790
791
792
    # Default sampling parameters for chat completion requests.
    # These are the last-resort fallbacks used by `to_sampling_params` and
    # `to_beam_search_params` when a value is set neither on the request
    # itself nor in the caller-supplied `default_sampling_params` dict.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self, max_tokens: int, default_sampling_params: dict
    ) -> BeamSearchParams:
        """Build the `BeamSearchParams` for this request.

        Args:
            max_tokens: Generation budget, passed through unchanged.
            default_sampling_params: Server-side defaults; only its
                "temperature" entry is consulted, and only when the request
                does not set a temperature.

        Returns:
            Beam-search parameters whose beam width is the request's `n`
            (1 when `n` is None).
        """
        beam_width = 1 if self.n is None else self.n

        # Precedence: request value > server default > class-level fallback.
        temperature = self.temperature
        if temperature is None:
            fallback = self._DEFAULT_SAMPLING_PARAMS["temperature"]
            temperature = default_sampling_params.get("temperature", fallback)

        return BeamSearchParams(
            beam_width=beam_width,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )
814

815
    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict,
    ) -> SamplingParams:
        """Translate this chat request into engine `SamplingParams`.

        Args:
            max_tokens: Generation budget, passed through unchanged.
            logits_processor_pattern: Forwarded to `get_logits_processors`
                together with the request's `logits_processors`.
            default_sampling_params: Server-side defaults consulted for
                `repetition_penalty`, `temperature`, `top_p`, `top_k` and
                `min_p` when the request leaves them unset.

        Returns:
            The object produced by `SamplingParams.from_optional`.

        Note:
            `self.structured_outputs` is populated in place from the
            deprecated `guided_*` fields and from `response_format`.
            `self.vllm_xargs` is NOT modified: `extra_args` is built on a
            copy so that adding `kv_transfer_params` does not leak into the
            request object.
        """

        def _resolve(name: str) -> Any:
            # Precedence: request value > server default > class fallback.
            value = getattr(self, name)
            if value is None:
                value = default_sampling_params.get(
                    name, self._DEFAULT_SAMPLING_PARAMS[name]
                )
            return value

        # Default parameters
        repetition_penalty = _resolve("repetition_penalty")
        temperature = _resolve("temperature")
        top_p = _resolve("top_p")
        top_k = _resolve("top_k")
        min_p = _resolve("min_p")

        # With `echo`, prompt logprobs default to the `top_logprobs` setting.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        # Forward deprecated guided_* parameters to structured_outputs
        if self.structured_outputs is None:
            kwargs = dict[str, Any](
                json=self.guided_json,
                regex=self.guided_regex,
                choice=self.guided_choice,
                grammar=self.guided_grammar,
                whitespace_pattern=self.guided_whitespace_pattern,
                structural_tag=self.structural_tag,
            )
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            if len(kwargs) > 0:
                self.structured_outputs = StructuredOutputsParams(**kwargs)

        response_format = self.response_format
        if response_format is not None:
            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            # Set structured output params for response format
            if response_format.type == "json_object":
                self.structured_outputs.json_object = True
            elif response_format.type == "json_schema":
                json_schema = response_format.json_schema
                assert json_schema is not None
                self.structured_outputs.json = json_schema.json_schema
            elif response_format.type == "structural_tag":
                structural_tag = response_format
                assert structural_tag is not None and isinstance(
                    structural_tag,
                    (
                        LegacyStructuralTagResponseFormat,
                        StructuralTagResponseFormat,
                    ),
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

        # Copy so that inserting kv_transfer_params below does not mutate
        # the request's own `vllm_xargs` field.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            bad_words=self.bad_words,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )
927

928
    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` on a request that is not streaming.

        Raises:
            ValueError: If `stream_options` is truthy while `stream` is not.
        """
        # Nothing to validate when no stream options were supplied.
        if not data.get("stream_options"):
            return data
        if data.get("stream"):
            return data
        raise ValueError("Stream options can only be defined when `stream=True`.")

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `top_logprobs` on the raw request.

        Raises:
            ValueError: If `prompt_logprobs` is requested together with
                streaming, is negative other than -1, or is -1 while
                `envs.VLLM_USE_V1` is false; or if `top_logprobs` is negative
                other than -1, or is requested without `logprobs`.
        """
        prompt_lp = data.get("prompt_logprobs")
        if prompt_lp is not None:
            # A non-zero request is rejected in combination with streaming.
            if data.get("stream") and (prompt_lp > 0 or prompt_lp == -1):
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`."
                )

            # -1 is the only negative value that is accepted.
            if prompt_lp != -1 and prompt_lp < 0:
                raise ValueError("`prompt_logprobs` must be a positive value or -1.")
            if prompt_lp == -1 and not envs.VLLM_USE_V1:
                raise ValueError(
                    "`prompt_logprobs=-1` is only supported with vLLM engine V1."
                )

        top_lp = data.get("top_logprobs")
        if top_lp is None:
            return data

        if top_lp != -1 and top_lp < 0:
            raise ValueError("`top_logprobs` must be a positive value or -1.")

        # A non-zero `top_logprobs` only makes sense with `logprobs` enabled.
        if not data.get("logprobs") and (top_lp == -1 or top_lp > 0):
            raise ValueError(
                "when using `top_logprobs`, `logprobs` must be set to true."
            )

        return data
961

962
963
    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Validate how `structured_outputs` constraints are combined.

        Only the 'json', 'regex' and 'choice' entries of
        `structured_outputs` are counted.

        Raises:
            ValueError: If more than one of those constraints is set, or if
                exactly one is set while `tool_choice` is anything other than
                "none", "auto" or "required" (a missing or null `tool_choice`
                is treated as "none"). Also re-raises `data` itself when it
                is already a `ValueError`.
        """
        if isinstance(data, ValueError):
            raise data

        if data.get("structured_outputs", None) is None:
            return data

        structured_outputs_kwargs = data["structured_outputs"]
        count = sum(
            structured_outputs_kwargs.get(k) is not None
            for k in ("json", "regex", "choice")
        )
        # you can only use one kind of constraints for structured outputs
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )

        # you can only either use structured outputs or tools, not both.
        # This check used to be guarded by `count > 1`, which can never be
        # true here because that case has already raised above; it must
        # fire when exactly one constraint is present.
        tool_choice = data.get("tool_choice", "none")
        if tool_choice is None:
            # An explicit null means "not specified".
            tool_choice = "none"
        if count == 1 and tool_choice not in (
            "none",
            "auto",
            "required",
        ):
            raise ValueError(
                "You can only either use constraints for structured outputs "
                "or tools, not both."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        """Validate `tool_choice` against `tools` on the raw request.

        Side effects on `data`:
            * `tool_choice` defaults to "auto" when tools are supplied
              without an explicit choice.
            * `tool_choice="required"` with an empty `tools` list is
              rewritten to `tool_choice="none"` and `tools` is removed.

        Raises:
            ValueError: If `tool_choice` is set without `tools`, is not one
                of "none"/"auto"/"required"/a dict, is a malformed named
                choice, or names a function that is absent from `tools`.
        """
        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        tool_choice = data.get("tool_choice")

        # if "tool_choice" is absent, null or "none" -- no validation is
        # needed for tools
        if tool_choice is None or tool_choice == "none":
            return data

        # ensure that if "tool choice" is specified, tools are present
        tools = data.get("tools")
        if tools is None:
            raise ValueError("When using `tool_choice`, `tools` must be set.")

        # make sure that tool choice is either a named tool
        # OR that it's set to "auto" or "required"
        if tool_choice not in ["auto", "required"] and not isinstance(
            tool_choice, dict
        ):
            raise ValueError(
                f"Invalid value for `tool_choice`: {tool_choice}! "
                'Only named tools, "none", "auto" or "required" '
                "are supported."
            )

        # if tool_choice is "required" but the "tools" list is empty,
        # override the data to behave like "none" to align with
        # OpenAI’s behavior.
        if tool_choice == "required" and isinstance(tools, list) and len(tools) == 0:
            data["tool_choice"] = "none"
            del data["tools"]
            return data

        if not isinstance(tool_choice, dict):
            return data

        # ensure that if "tool_choice" is specified as an object,
        # it matches a valid tool
        correct_usage_message = (
            'Correct usage: `{"type": "function",'
            ' "function": {"name": "my_function"}}`'
        )
        function = tool_choice.get("function")
        if not isinstance(function, dict):
            raise ValueError(
                f"Invalid value for `function`: `{function}` in "
                f"`tool_choice`! {correct_usage_message}"
            )
        if "name" not in function:
            raise ValueError(
                "Expected field `name` in `function` in "
                f"`tool_choice`! {correct_usage_message}"
            )
        function_name = function["name"]
        if not isinstance(function_name, str) or len(function_name) == 0:
            raise ValueError(
                f"Invalid `name` in `function`: `{function_name}`"
                f" in `tool_choice`! {correct_usage_message}"
            )

        # `tools` is still raw, unvalidated input at this point. Look the
        # name up defensively so a malformed entry counts as "no match"
        # instead of escaping as a KeyError/TypeError.
        candidate_tools = tools if isinstance(tools, (list, tuple)) else ()
        valid_tool = any(
            isinstance(tool, dict)
            and isinstance(tool.get("function"), dict)
            and tool["function"].get("name") == function_name
            for tool in candidate_tools
        )
        if not valid_tool:
            raise ValueError(
                "The tool specified in `tool_choice` does not match any"
                " of the specified `tools`"
            )
        return data

1071
1072
1073
    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests that enable two mutually exclusive prompt flags.

        Raises:
            ValueError: If both `continue_final_message` and
                `add_generation_prompt` are truthy.
        """
        continue_final = data.get("continue_final_message")
        if continue_final and data.get("add_generation_prompt"):
            message = (
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
            raise ValueError(message)
        return data

1081
1082
1083
1084
1085
1086
1087
    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate the optional `cache_salt` request parameter.

        Raises:
            ValueError: If `cache_salt` is given while `envs.VLLM_USE_V1` is
                false, or if it is not a non-empty string.
        """
        salt = data.get("cache_salt")
        if salt is None:
            # No salt requested; nothing to check.
            return data

        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Parameter 'cache_salt' is not supported with "
                "this instance of vLLM, which uses engine V0."
            )
        if not (isinstance(salt, str) and salt):
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data

Zhuohan Li's avatar
Zhuohan Li committed
1096

1097
class CompletionRequest(OpenAIBaseModel):
    # Request body of the text-completions endpoint. Besides the OpenAI
    # fields it carries vLLM sampling and extension parameters, and it can
    # be converted into engine parameters via `to_sampling_params` /
    # `to_beam_search_params`.

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str | None = None
    prompt: list[int] | list[list[int]] | str | list[str] | None = None
    best_of: int | None = None
    echo: bool | None = False
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: int | None = None
    max_tokens: int | None = 16
    n: int = 1
    presence_penalty: float | None = 0.0
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    suffix: str | None = None
    temperature: float | None = None
    top_p: float | None = None
    user: str | None = None

    # --8<-- [start:completion-sampling-params]
    use_beam_search: bool = False
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
    length_penalty: float = 1.0
    stop_token_ids: list[int] | None = []
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    allowed_token_ids: list[int] | None = None
    prompt_logprobs: int | None = None
    # --8<-- [end:completion-sampling-params]

    # --8<-- [start:completion-extra-params]
    prompt_embeds: bytes | list[bytes] | None = None
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    response_format: AnyResponseFormat | None = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
            ", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
        ),
    )
    structured_outputs: StructuredOutputsParams | None = Field(
        default=None,
        description="Additional kwargs for structured outputs",
    )
    guided_json: str | dict | BaseModel | None = Field(
        default=None,
        description=(
            "`guided_json` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `json` to `structured_outputs` instead."
        ),
    )
    guided_regex: str | None = Field(
        default=None,
        description=(
            "`guided_regex` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `regex` to `structured_outputs` instead."
        ),
    )
    guided_choice: list[str] | None = Field(
        default=None,
        description=(
            "`guided_choice` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `choice` to `structured_outputs` instead."
        ),
    )
    guided_grammar: str | None = Field(
        default=None,
        description=(
            "`guided_grammar` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `grammar` to `structured_outputs` instead."
        ),
    )
    structural_tag: str | None = Field(
        default=None,
        description=("If specified, the output will follow the structural tag schema."),
    )
    guided_decoding_backend: str | None = Field(
        default=None,
        description=(
            "`guided_decoding_backend` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please remove it from your request."
        ),
    )
    guided_whitespace_pattern: str | None = Field(
        default=None,
        description=(
            "`guided_whitespace_pattern` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `whitespace_pattern` to `structured_outputs` instead."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    logits_processors: LogitsProcessors | None = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."
        ),
    )

    return_tokens_as_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."
        ),
    )
    return_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."
        ),
    )

    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."
        ),
    )

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )

    # --8<-- [end:completion-extra-params]

    # Default sampling parameters for completion requests.
    # Last-resort fallbacks used when a value is set neither on the request
    # nor in the caller-supplied `default_sampling_params`.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self,
        max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> BeamSearchParams:
        """Build the `BeamSearchParams` for this request.

        Args:
            max_tokens: Generation budget, passed through unchanged.
            default_sampling_params: Optional server-side defaults; only the
                "temperature" entry is consulted, and only when the request
                does not set a temperature.

        Returns:
            Beam-search parameters whose beam width is the request's `n`
            (1 when `n` is None).
        """
        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        # Fall back to the shared class-level default rather than a second,
        # hard-coded copy of the same value.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Translate this completion request into engine `SamplingParams`.

        Args:
            max_tokens: Generation budget. Replaced by 1 when the request
                asks for `echo` with `max_tokens == 0`.
            logits_processor_pattern: Forwarded to `get_logits_processors`
                together with the request's `logits_processors`.
            default_sampling_params: Optional server-side defaults consulted
                for `repetition_penalty`, `temperature`, `top_p`, `top_k`
                and `min_p` when the request leaves them unset.

        Returns:
            The object produced by `SamplingParams.from_optional`.

        Note:
            `self.guided_json`, `self.structural_tag` and
            `self.structured_outputs` may be populated in place from
            `response_format` and the deprecated `guided_*` fields.
            `self.vllm_xargs` is NOT modified.
        """
        if default_sampling_params is None:
            default_sampling_params = {}

        def _resolve(name: str) -> Any:
            # Precedence: request value > server default > class fallback.
            value = getattr(self, name)
            if value is None:
                value = default_sampling_params.get(
                    name, self._DEFAULT_SAMPLING_PARAMS[name]
                )
            return value

        # Default parameters
        repetition_penalty = _resolve("repetition_penalty")
        temperature = _resolve("temperature")
        top_p = _resolve("top_p")
        top_k = _resolve("top_k")
        min_p = _resolve("min_p")

        # With `echo`, prompt logprobs default to the `logprobs` setting.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        echo_without_generation = self.echo and self.max_tokens == 0

        guided_json_object = None
        if self.response_format is not None:
            if self.response_format.type == "json_object":
                guided_json_object = True
            elif self.response_format.type == "json_schema":
                json_schema = self.response_format.json_schema
                assert json_schema is not None
                self.guided_json = json_schema.json_schema
            elif self.response_format.type == "structural_tag":
                structural_tag = self.response_format
                assert structural_tag is not None and isinstance(
                    structural_tag,
                    (
                        LegacyStructuralTagResponseFormat,
                        StructuralTagResponseFormat,
                    ),
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structural_tag = json.dumps(s_tag_obj)

        # Forward deprecated guided_* parameters to structured_outputs
        if self.structured_outputs is None:
            kwargs = dict[str, Any](
                json=self.guided_json,
                json_object=guided_json_object,
                regex=self.guided_regex,
                choice=self.guided_choice,
                grammar=self.guided_grammar,
                whitespace_pattern=self.guided_whitespace_pattern,
                # Without this entry a structural tag (given directly or
                # via `response_format`) would be silently dropped.
                structural_tag=self.structural_tag,
            )
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            if len(kwargs) > 0:
                self.structured_outputs = StructuredOutputsParams(**kwargs)

        # Copy so that inserting kv_transfer_params below does not mutate
        # the request's own `vllm_xargs` field.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )

    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Allow at most one of the 'json', 'regex', 'choice' constraints.

        Raises:
            ValueError: If more than one of those keys of
                `structured_outputs` is not None.
        """
        if data.get("structured_outputs", None) is None:
            return data

        structured_outputs_kwargs = data["structured_outputs"]
        count = sum(
            structured_outputs_kwargs.get(k) is not None
            for k in ("json", "regex", "choice")
        )
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `logprobs` on the raw request.

        Raises:
            ValueError: If `prompt_logprobs` is requested together with
                streaming, is negative other than -1, or is -1 while
                `envs.VLLM_USE_V1` is false; or if `logprobs` is negative.
        """
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`."
                )

            # -1 is the only negative value that is accepted.
            if prompt_logprobs < 0 and prompt_logprobs != -1:
                raise ValueError("`prompt_logprobs` must be a positive value or -1.")
            if prompt_logprobs == -1 and not envs.VLLM_USE_V1:
                raise ValueError(
                    "`prompt_logprobs=-1` is only supported with vLLM engine V1."
                )
        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise ValueError("`logprobs` must be a positive value.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` on a request that is not streaming.

        Raises:
            ValueError: If `stream_options` is truthy while `stream` is not.
        """
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt_and_prompt_embeds(cls, data):
        """Require a usable `prompt` or `prompt_embeds`.

        `prompt` counts as empty when it is None or the empty string;
        `prompt_embeds` counts as empty when it is None or an empty list.

        Raises:
            ValueError: If both inputs are empty.
        """
        prompt = data.get("prompt")
        prompt_embeds = data.get("prompt_embeds")

        prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "")
        embeds_is_empty = prompt_embeds is None or (
            isinstance(prompt_embeds, list) and len(prompt_embeds) == 0
        )

        if prompt_is_empty and embeds_is_empty:
            raise ValueError(
                "Either prompt or prompt_embeds must be provided and non-empty."
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate the optional `cache_salt` request parameter.

        Raises:
            ValueError: If `cache_salt` is given while `envs.VLLM_USE_V1` is
                false, or if it is not a non-empty string.
        """
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0."
                )
            if not isinstance(data["cache_salt"], str) or not data["cache_salt"]:
                raise ValueError(
                    "Parameter 'cache_salt' must be a non-empty string if provided."
                )
        return data

Zhuohan Li's avatar
Zhuohan Li committed
1500

1501
class EmbeddingCompletionRequest(OpenAIBaseModel):
    """Embedding request whose input is text or token ids."""

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str | None = None
    input: list[int] | list[list[int]] | str | list[str]
    encoding_format: EncodingFormat = "float"
    dimensions: int | None = None
    user: str | None = None
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:embedding-extra-params]
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    normalize: bool | None = Field(
        default=None,
        description="Whether to normalize the embeddings outputs. Default is True.",
    )
    embed_dtype: EmbedDType = Field(
        default="float32",
        description=(
            "What dtype to use for encoding. Default to using float32 for base64 "
            "encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    endianness: Endianness = Field(
        default="native",
        description=(
            "What endianness to use for encoding. Default to using native for "
            # Trailing space keeps the two sentences from running together.
            "base64 encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    # --8<-- [end:embedding-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from this request.

        Forwards `truncate_prompt_tokens`, `dimensions` and `normalize`.
        """
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            dimensions=self.dimensions,
            normalize=self.normalize,
        )
1563
1564


1565
class EmbeddingChatRequest(OpenAIBaseModel):
    """Embedding request whose input is a list of chat messages."""

    model: str | None = None
    messages: list[ChatCompletionMessageParam]

    encoding_format: EncodingFormat = "float"
    dimensions: int | None = None
    user: str | None = None
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:chat-embedding-extra-params]
    add_generation_prompt: bool = Field(
        default=False,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )

    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    normalize: bool | None = Field(
        default=None,
        description="Whether to normalize the embeddings outputs. Default is True.",
    )
    embed_dtype: EmbedDType = Field(
        default="float32",
        description=(
            "What dtype to use for encoding. Default to using float32 for base64 "
            "encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    endianness: Endianness = Field(
        default="native",
        description=(
            "What endianness to use for encoding. Default to using native for "
            # Trailing space keeps the two sentences from running together.
            "base64 encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    # --8<-- [end:chat-embedding-extra-params]

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject payloads that set both prompt-continuation flags.

        NOTE(review): `continue_final_message` is not declared as a field on
        this class; the check reads it straight from the raw payload —
        confirm that is intended.

        Raises:
            ValueError: If `continue_final_message` and
                `add_generation_prompt` are both truthy.
        """
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

    def to_pooling_params(self):
        """Build `PoolingParams` from this request.

        Forwards `truncate_prompt_tokens`, `dimensions` and `normalize`.
        """
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            dimensions=self.dimensions,
            normalize=self.normalize,
        )
1668
1669


1670
# Either flavour of embedding request: completion-style or chat-style input.
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
1671

1672
1673
# Pooling requests reuse the embedding request models unchanged.
PoolingCompletionRequest = EmbeddingCompletionRequest
PoolingChatRequest = EmbeddingChatRequest
1674
1675
1676
1677
1678

# Type parameter for the payload carried by the generic IOProcessor models.
T = TypeVar("T")


class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
    """Generic request model whose `data` payload has the type parameter `T`."""

    model: str | None = None

    priority: int = Field(default=0)
    """
    The priority of the request (lower means earlier handling;
    default: 0). Any priority other than 0 will raise an error
    if the served model does not use priority scheduling.
    """
    data: T

    encoding_format: EncodingFormat = "float"
    embed_dtype: EmbedDType = Field(
        default="float32",
        description=(
            "What dtype to use for encoding. Default to using float32 for base64 "
            "encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    endianness: Endianness = Field(
        default="native",
        description=(
            "What endianness to use for encoding. Default to using native for "
            # Trailing space keeps the two sentences from running together.
            "base64 encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )

    def to_pooling_params(self):
        """Return a default-constructed `PoolingParams`; no field is forwarded."""
        return PoolingParams()
1709
1710
1711


class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
    """Generic response model whose `data` payload has the type parameter `T`."""

    request_id: str | None = None
    """
    The request_id associated with this response
    """
    # Defaults to the current time in whole seconds, taken at construction.
    created_at: int = Field(default_factory=lambda: int(time.time()))

    data: T
    """
    When using IOProcessor plugins, the actual output is generated
    by the plugin itself. Hence, we use a generic type for the response data
    """


1725
1726
1727
# Any request model accepted for pooling: completion-style, chat-style,
# or the generic IOProcessor request.
PoolingRequest: TypeAlias = (
    PoolingCompletionRequest | PoolingChatRequest | IOProcessorRequest
)
1728

1729

1730
class ScoreRequest(OpenAIBaseModel):
    """Score request carrying two inputs, `text_1` and `text_2`.

    Each input is a string, a list of strings, or a `ScoreMultiModalParam`.
    """

    model: str | None = None
    text_1: list[str] | str | ScoreMultiModalParam
    text_2: list[str] | str | ScoreMultiModalParam
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:score-extra-params]

    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: bool | None = None

    # --8<-- [end:score-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from `truncate_prompt_tokens` and `activation`."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation,
        )
1761
1762


1763
class RerankRequest(OpenAIBaseModel):
    """Rerank request carrying a `query` and the `documents` to rank."""

    model: str | None = None
    query: str | ScoreMultiModalParam
    documents: list[str] | ScoreMultiModalParam
    # A plain immutable default needs no factory.
    top_n: int = Field(default=0)
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:rerank-extra-params]

    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: bool | None = None

    # --8<-- [end:rerank-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from `truncate_prompt_tokens` and `activation`."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation,
        )
1795
1796
1797


class RerankDocument(BaseModel):
    """Document attached to a rerank result; both fields are optional."""

    text: str | None = None
    multi_modal: ScoreContentPartParam | None = None
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815


class RerankResult(BaseModel):
    """One entry of `RerankResponse.results`."""

    # presumably the position of the document in the request — confirm
    index: int
    document: RerankDocument
    relevance_score: float


class RerankUsage(BaseModel):
    """Token accounting reported in `RerankResponse.usage`."""

    total_tokens: int


class RerankResponse(OpenAIBaseModel):
    """Rerank response: identifiers, token usage and the list of results."""

    id: str
    model: str
    usage: RerankUsage
    results: list[RerankResult]
1817
1818


1819
class CompletionLogProbs(OpenAIBaseModel):
    """Log-probability lists for a text completion.

    Every field defaults to a fresh empty list.
    """

    text_offset: list[int] = Field(default_factory=list)
    token_logprobs: list[float | None] = Field(default_factory=list)
    tokens: list[str] = Field(default_factory=list)
    top_logprobs: list[dict[str, float] | None] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
1824
1825


1826
class CompletionResponseChoice(OpenAIBaseModel):
    """One choice of a non-streaming text completion response."""

    index: int
    text: str
    logprobs: CompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    token_ids: list[int] | None = None  # For response
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    prompt_token_ids: list[int] | None = None  # For prompt
Zhuohan Li's avatar
Zhuohan Li committed
1842
1843


1844
class CompletionResponse(OpenAIBaseModel):
    """Non-streaming text completion response."""

    # Generated as "cmpl-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: Literal["text_completion"] = "text_completion"
    # Defaults to the current time in whole seconds, taken at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None, description="KVTransfer parameters."
    )
Zhuohan Li's avatar
Zhuohan Li committed
1858
1859


1860
class CompletionResponseStreamChoice(OpenAIBaseModel):
    """One choice inside a streamed text completion chunk."""

    index: int
    text: str
    logprobs: CompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    # not part of the OpenAI spec but for tracing the tokens
    # prompt tokens is put into choice to align with CompletionResponseChoice
    prompt_token_ids: list[int] | None = None
    token_ids: list[int] | None = None
Zhuohan Li's avatar
Zhuohan Li committed
1877
1878


1879
class CompletionStreamResponse(OpenAIBaseModel):
    """One chunk of a streamed text completion response."""

    # Generated as "cmpl-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Defaults to the current time in whole seconds, taken at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
1886
1887


1888
class EmbeddingResponseData(OpenAIBaseModel):
    """One embedding entry of an `EmbeddingResponse`."""

    index: int
    object: str = "embedding"
    # Either a list of floats or a string.
    # presumably the string is a base64 encoding — confirm against the producer
    embedding: list[float] | str
1892
1893


1894
class EmbeddingResponse(OpenAIBaseModel):
    """Embedding response: a list of embedding entries plus usage."""

    # Generated as "embd-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    # Defaults to the current time in whole seconds, taken at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[EmbeddingResponseData]
    usage: UsageInfo


1903
1904
1905
1906
1907
1908
class EmbeddingBytesResponse(OpenAIBaseModel):
    """Embedding response carried as raw byte chunks plus a metadata string."""

    body: list[bytes]
    metadata: str
    media_type: str = "application/octet-stream"


1909
1910
1911
class PoolingResponseData(OpenAIBaseModel):
    """One pooling entry of a `PoolingResponse`."""

    index: int
    object: str = "pooling"
    # A list of float lists, a flat list of floats, or a string.
    data: list[list[float]] | list[float] | str
1913
1914
1915
1916
1917
1918
1919


class PoolingResponse(OpenAIBaseModel):
    """Pooling response: a list of pooling entries plus usage."""

    # Generated as "pool-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
    object: str = "list"
    # Defaults to the current time in whole seconds, taken at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[PoolingResponseData]
    usage: UsageInfo


1924
1925
1926
1927
1928
1929
class PoolingBytesResponse(OpenAIBaseModel):
    """Pooling response carried as raw byte chunks plus a metadata string."""

    body: list[bytes]
    metadata: str
    media_type: str = "application/octet-stream"


1930
1931
1932
class ScoreResponseData(OpenAIBaseModel):
    """One score entry of a `ScoreResponse`."""

    index: int
    object: str = "score"
    score: float
1934
1935
1936
1937
1938
1939
1940


class ScoreResponse(OpenAIBaseModel):
    """Score response: a list of score entries plus usage."""

    # NOTE(review): the default id prefix is "embd-", the same prefix that
    # EmbeddingResponse uses — confirm this is intended for score responses.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    # Defaults to the current time in whole seconds, taken at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ScoreResponseData]
    usage: UsageInfo


1945
class ClassificationRequest(OpenAIBaseModel):
    """Classification request over one string or a list of strings."""

    model: str | None = None
    input: list[str] | str
    # Same `ge=-1` bound as the embedding, score and rerank request models.
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    user: str | None = None

    # --8<-- [start:classification-extra-params]
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: bool | None = None

    # --8<-- [end:classification-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from `truncate_prompt_tokens` and `activation`."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation,
        )
1970
1971
1972
1973


class ClassificationData(OpenAIBaseModel):
    """One classification entry of a `ClassificationResponse`."""

    index: int
    # Required field, but None is an accepted value.
    label: str | None
    probs: list[float]
    num_classes: int


class ClassificationResponse(OpenAIBaseModel):
    """Classification response: a list of classification entries plus usage."""

    # Generated as "classify-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
    object: str = "list"
    # Defaults to the current time in whole seconds, taken at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ClassificationData]
    usage: UsageInfo


1988
1989
1990
1991
1992
1993
class FunctionCall(OpenAIBaseModel):
    """A function invocation: the function `name` and its `arguments` string."""

    name: str
    # presumably a JSON-encoded argument object — confirm against producers
    arguments: str


class ToolCall(OpenAIBaseModel):
    """A complete tool call; the only supported `type` is "function"."""

    # Generated by `make_tool_call_id` when not supplied.
    id: str = Field(default_factory=make_tool_call_id)
    type: Literal["function"] = "function"
    function: FunctionCall


1999
class DeltaFunctionCall(BaseModel):
    """Function-call fragment in which both fields are optional."""

    name: str | None = None
    arguments: str | None = None
2002
2003
2004
2005


# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
    """Tool-call fragment; every field except `index` is optional."""

    id: str | None = None
    type: Literal["function"] | None = None
    index: int
    function: DeltaFunctionCall | None = None
2010
2011
2012
2013
2014
2015
2016


class ExtractedToolCallInformation(BaseModel):
    """Tool calls extracted from output, plus any remaining content."""

    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: list[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: str | None = None
2022
2023


2024
class ChatMessage(OpenAIBaseModel):
    """Message object of a non-streaming chat completion choice."""

    role: str
    content: str | None = None
    refusal: str | None = None
    annotations: OpenAIAnnotation | None = None
    audio: OpenAIChatCompletionAudio | None = None
    function_call: FunctionCall | None = None
    # Defaults to a fresh empty list rather than None.
    tool_calls: list[ToolCall] = Field(default_factory=list)

    # vLLM-specific fields that are not in OpenAI spec
    reasoning_content: str | None = None
2035

2036

2037
2038
2039
class ChatCompletionLogProb(OpenAIBaseModel):
    """Log probability of a single token."""

    token: str
    # Default used when no log probability is supplied.
    logprob: float = -9999.0
    bytes: list[int] | None = None
2041
2042
2043


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    """A token's log probability together with its top alternatives."""

    # Workaround: redefine fields name cache so that it's not
    # shared with the super class.
    field_names: ClassVar[set[str] | None] = None
    top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
2048
2049
2050


class ChatCompletionLogProbs(OpenAIBaseModel):
    """Container for the per-token log probabilities of a chat choice."""

    content: list[ChatCompletionLogProbsContent] | None = None
2052
2053


2054
class ChatCompletionResponseChoice(OpenAIBaseModel):
    """One choice of a non-streaming chat completion response."""

    index: int
    message: ChatMessage
    logprobs: ChatCompletionLogProbs | None = None
    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: int | str | None = None
    # not part of the OpenAI spec but is useful for tracing the tokens
    # in agent scenarios
    token_ids: list[int] | None = None
2065
2066


2067
class ChatCompletionResponse(OpenAIBaseModel):
    """Non-streaming chat completion response."""

    # Generated as "chatcmpl-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    # Defaults to the current time in whole seconds, taken at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    prompt_token_ids: list[int] | None = None
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None, description="KVTransfer parameters."
    )
2083
2084


2085
class DeltaMessage(OpenAIBaseModel):
    """Incremental message fragment carried by streaming choices."""

    role: str | None = None
    content: str | None = None
    reasoning_content: str | None = None
    # Defaults to a fresh empty list rather than None.
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)
2090
2091


2092
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    """One choice inside a streamed chat completion chunk."""

    index: int
    delta: DeltaMessage
    logprobs: ChatCompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None
    # not part of the OpenAI spec but for tracing the tokens
    token_ids: list[int] | None = None
2100
2101


2102
class ChatCompletionStreamResponse(OpenAIBaseModel):
    """One chunk of a streamed chat completion response."""

    # Generated as "chatcmpl-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Defaults to the current time in whole seconds, taken at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
    # not part of the OpenAI spec but for tracing the tokens
    prompt_token_ids: list[int] | None = None
2111
2112


2113
2114
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    """One choice inside a streamed transcription chunk."""

    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None
2117
2118
2119
2120
2121
2122
2123
2124


class TranscriptionStreamResponse(OpenAIBaseModel):
    """One chunk of a streamed transcription response."""

    # Generated as "trsc-<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    object: Literal["transcription.chunk"] = "transcription.chunk"
    # Defaults to the current time in whole seconds, taken at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
2126
2127


2128
2129
class InputTokensDetails(OpenAIBaseModel):
    """Breakdown of input tokens reported in `ResponseUsage`."""

    cached_tokens: int
    # Per-turn lists default to fresh empty lists.
    input_tokens_per_turn: list[int] = Field(default_factory=list)
    cached_tokens_per_turn: list[int] = Field(default_factory=list)
2132
2133
2134


class OutputTokensDetails(OpenAIBaseModel):
    """Breakdown of output tokens reported in `ResponseUsage`."""

    reasoning_tokens: int = 0
    tool_output_tokens: int = 0
    # Per-turn lists default to fresh empty lists.
    output_tokens_per_turn: list[int] = Field(default_factory=list)
    tool_output_tokens_per_turn: list[int] = Field(default_factory=list)
2139
2140
2141
2142
2143
2144
2145
2146


class ResponseUsage(OpenAIBaseModel):
    """Token usage attached to a `ResponsesResponse`; every field is required."""

    input_tokens: int
    input_tokens_details: InputTokensDetails
    output_tokens: int
    output_tokens_details: OutputTokensDetails
    total_tokens: int
2147
2148


2149
2150
2151
2152
2153
2154
def serialize_message(msg):
    """Convert a single message into a serializable form.

    Args:
        msg: A dict, an object exposing `to_dict()`, or an object exposing
            `model_dump_json()`.

    Returns:
        `msg` itself when it is a dict, the result of `msg.to_dict()` when
        that method exists, otherwise the result of `msg.model_dump_json()`.
    """
    if isinstance(msg, dict):
        return msg
    if hasattr(msg, "to_dict"):
        return msg.to_dict()
    # Fall back to the pydantic JSON dump.
    # NOTE(review): this branch returns whatever `model_dump_json()` yields
    # (a JSON string for pydantic models), unlike the dict-returning branches
    # above — confirm callers expect that.
    return msg.model_dump_json()


def serialize_messages(msgs):
    """Serialize every message in `msgs`.

    Returns:
        None when `msgs` is falsy (None or empty); otherwise a list holding
        `serialize_message(m)` for each element, in order.
    """
    if not msgs:
        return None
    return [serialize_message(item) for item in msgs]


2169
2170
2171
2172
class ResponsesResponse(OpenAIBaseModel):
    """Response object for the Responses API.

    Instances are normally built with `from_request`, which copies settings
    from the originating request and its sampling parameters.
    """

    # Generated as "resp_<random uuid>" when not supplied.
    id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
    # Defaults to the current time in whole seconds, taken at construction.
    created_at: int = Field(default_factory=lambda: int(time.time()))
    # error: Optional[ResponseError] = None
    incomplete_details: IncompleteDetails | None = None
    instructions: str | None = None
    metadata: Metadata | None = None
    model: str
    object: Literal["response"] = "response"
    output: list[ResponseOutputItem]
    parallel_tool_calls: bool
    temperature: float
    tool_choice: ToolChoice
    tools: list[Tool]
    top_p: float
    background: bool
    max_output_tokens: int
    max_tool_calls: int | None = None
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"]
    status: ResponseStatus
    text: ResponseTextConfig | None = None
    top_logprobs: int | None = None
    truncation: Literal["auto", "disabled"]
    usage: ResponseUsage | None = None
    user: str | None = None

    # --8<-- [start:responses-extra-params]
    # These are populated when enable_response_messages is set to True
    # NOTE: custom serialization is needed
    # see serialize_input_messages and serialize_output_messages
    input_messages: list[ChatCompletionMessageParam] | None = None
    output_messages: list[ChatCompletionMessageParam] | None = None
    # --8<-- [end:responses-extra-params]

    # NOTE: openAI harmony doesn't serialize TextContent properly,
    # TODO: this fixes for TextContent, but need to verify for tools etc
    # https://github.com/openai/harmony/issues/78
    @field_serializer("output_messages", when_used="json")
    def serialize_output_messages(self, msgs, _info):
        """Serialize `output_messages` via `serialize_messages` for JSON output."""
        return serialize_messages(msgs)

    # NOTE: openAI harmony doesn't serialize TextContent properly, this fixes it
    # https://github.com/openai/harmony/issues/78
    @field_serializer("input_messages", when_used="json")
    def serialize_input_messages(self, msgs, _info):
        """Serialize `input_messages` via `serialize_messages` for JSON output."""
        return serialize_messages(msgs)

    @classmethod
    def from_request(
        cls,
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        model_name: str,
        created_time: int,
        output: list[ResponseOutputItem],
        status: ResponseStatus,
        usage: ResponseUsage | None = None,
        input_messages: list[ChatCompletionMessageParam] | None = None,
        output_messages: list[ChatCompletionMessageParam] | None = None,
    ) -> "ResponsesResponse":
        """Build a response from a request and its sampling parameters.

        Args:
            request: Source of the id (`request.request_id`) and of the
                request-level settings echoed back in the response.
            sampling_params: Source of `temperature`, `top_p`,
                `max_output_tokens` (from `max_tokens`) and `top_logprobs`
                (from `logprobs`).
            model_name: Value for the `model` field.
            created_time: Value for the `created_at` field.
            output: Output items of the response.
            status: Response status; "incomplete" also sets
                `incomplete_details`.
            usage: Optional token usage.
            input_messages: Optional input messages to attach.
            output_messages: Optional output messages to attach.

        Returns:
            The populated `ResponsesResponse`.
        """
        incomplete_details: IncompleteDetails | None = None
        if status == "incomplete":
            incomplete_details = IncompleteDetails(reason="max_output_tokens")
        # TODO: implement the other reason for incomplete_details,
        # which is content_filter
        # incomplete_details = IncompleteDetails(reason='content_filter')
        return cls(
            id=request.request_id,
            created_at=created_time,
            incomplete_details=incomplete_details,
            instructions=request.instructions,
            metadata=request.metadata,
            model=model_name,
            output=output,
            input_messages=input_messages,
            output_messages=output_messages,
            parallel_tool_calls=request.parallel_tool_calls,
            temperature=sampling_params.temperature,
            tool_choice=request.tool_choice,
            tools=request.tools,
            top_p=sampling_params.top_p,
            background=request.background,
            max_output_tokens=sampling_params.max_tokens,
            max_tool_calls=request.max_tool_calls,
            previous_response_id=request.previous_response_id,
            prompt=request.prompt,
            reasoning=request.reasoning,
            service_tier=request.service_tier,
            status=status,
            text=request.text,
            top_logprobs=sampling_params.logprobs,
            truncation=request.truncation,
            user=request.user,
            usage=usage,
        )


2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
class ResponseReasoningPartDoneEvent(OpenAIBaseModel):
    # Streaming event whose `type` is always "response.reasoning_part.done".
    content_index: int
    """The index of the content part that is done."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that is done."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.done"]
    """The type of the event. Always `response.reasoning_part.done`."""


# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
class ResponseReasoningPartAddedEvent(OpenAIBaseModel):
    # Streaming event whose `type` is always "response.reasoning_part.added".
    content_index: int
    """The index of the content part that was added."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that was added."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.added"]
    """The type of the event. Always `response.reasoning_part.added`."""


2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
# vLLM Streaming Events
# Note: we override the response type with the vLLM ResponsesResponse type
class ResponseCompletedEvent(OpenAIResponseCompletedEvent):
    # Narrow `response` to the vLLM ResponsesResponse model.
    response: ResponsesResponse  # type: ignore[override]


class ResponseCreatedEvent(OpenAIResponseCreatedEvent):
    # Narrow `response` to the vLLM ResponsesResponse model.
    response: ResponsesResponse  # type: ignore[override]


class ResponseInProgressEvent(OpenAIResponseInProgressEvent):
    # Re-declare `response` with vLLM's ResponsesResponse type; the ignore
    # silences the type checker's complaint about overriding the parent field.
    response: ResponsesResponse  # type: ignore[override]


2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
# Union of every event type that may be emitted on a streaming Responses API
# stream: the three vLLM overrides (created / in-progress / completed), the
# locally defined reasoning-part events, and the event models imported from
# `openai.types.responses`.
StreamingResponsesResponse: TypeAlias = (
    ResponseCreatedEvent
    | ResponseInProgressEvent
    | ResponseCompletedEvent
    | ResponseOutputItemAddedEvent
    | ResponseOutputItemDoneEvent
    | ResponseContentPartAddedEvent
    | ResponseContentPartDoneEvent
    | ResponseReasoningTextDeltaEvent
    | ResponseReasoningTextDoneEvent
    | ResponseReasoningPartAddedEvent
    | ResponseReasoningPartDoneEvent
    | ResponseCodeInterpreterCallInProgressEvent
    | ResponseCodeInterpreterCallCodeDeltaEvent
    | ResponseWebSearchCallInProgressEvent
    | ResponseWebSearchCallSearchingEvent
    | ResponseWebSearchCallCompletedEvent
    | ResponseCodeInterpreterCallCodeDoneEvent
    | ResponseCodeInterpreterCallInterpretingEvent
    | ResponseCodeInterpreterCallCompletedEvent
)
2348

2349
2350
2351
# Request models that may appear as the `body` of one batch input line.
BatchRequestInputBody: TypeAlias = (
    ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest
)
2352
2353


2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: The request `body` is validated against the request model selected
    by `url`: `/v1/chat/completions`, `/v1/embeddings`, and URLs ending in
    `/score` or `/rerank` are recognized.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. It selects which
    # request model `body` is validated against (see `check_type_for_url`).
    url: str

    # The parameters of the request.
    body: BatchRequestInputBody

    @field_validator("body", mode="plain")
    @classmethod
    def check_type_for_url(cls, value: Any, info: ValidationInfo):
        """Validate `body` against the request model implied by `url`.

        Args:
            value: The raw, unvalidated `body` payload.
            info: Validation context; `info.data` holds the fields declared
                before `body` that validated successfully.

        Returns:
            The validated request model instance.
        """
        # `info.data` omits fields that failed validation, so `url` is absent
        # when it was missing or malformed. Look it up defensively so that
        # case is reported as a normal validation error for `url` rather than
        # escaping as a bare KeyError from this validator.
        url = info.data.get("url")
        if isinstance(url, str):
            # Use url to disambiguate models
            if url == "/v1/chat/completions":
                return ChatCompletionRequest.model_validate(value)
            if url == "/v1/embeddings":
                return TypeAdapter(EmbeddingRequest).validate_python(value)
            if url.endswith("/score"):
                return ScoreRequest.model_validate(value)
            if url.endswith("/rerank"):
                return RerankRequest.model_validate(value)
        # Unknown or unavailable url: let pydantic pick the matching member of
        # the body union.
        return TypeAdapter(BatchRequestInputBody).validate_python(value)
2390

2391

2392
2393
2394
2395
2396
2397
2398
2399
class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response; None when no response body is available.
    body: (
        ChatCompletionResponse
        | EmbeddingResponse
        | ScoreResponse
        | RerankResponse
        | None
    ) = None
2407
2408


2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    # Required field, but may be explicitly None.
    response: BatchResponseData | None

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Any | None
2425
2426


2427
class TokenizeCompletionRequest(OpenAIBaseModel):
    # Tokenize request carrying a plain-text `prompt`.
    model: str | None = None
    prompt: str

    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
2444
2445
2446


class TokenizeChatRequest(OpenAIBaseModel):
    # Tokenize request carrying chat `messages` plus chat-template options.
    model: str | None = None
    messages: list[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    tools: list[ChatCompletionToolsParam] | None = Field(
        default=None,
        description=("A list of tools the model may call."),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject input that sets both mutually exclusive prompt flags.

        The check runs on the raw input, before field defaults are applied,
        so only values given explicitly in the input are considered.

        Raises:
            ValueError: If both `continue_final_message` and
                `add_generation_prompt` are truthy in the input.
        """
        # A "before" validator receives whatever was passed for validation,
        # which is not guaranteed to be a dict; only inspect it when it is.
        if (
            isinstance(data, dict)
            and data.get("continue_final_message")
            and data.get("add_generation_prompt")
        ):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

2519

2520
# A tokenize request is either prompt-based or chat-message-based.
TokenizeRequest: TypeAlias = TokenizeCompletionRequest | TokenizeChatRequest
2521
2522
2523
2524
2525


class TokenizeResponse(OpenAIBaseModel):
    # Result of a tokenize request.
    count: int
    max_model_len: int
    tokens: list[int]
    # Optional; defaults to None when token strings are not provided.
    token_strs: list[str] | None = None
2528
2529
2530


class DetokenizeRequest(OpenAIBaseModel):
    # Request to turn a list of token ids back into text.
    model: str | None = None
    tokens: list[int]
2533
2534
2535
2536


class DetokenizeResponse(OpenAIBaseModel):
    # Response to a detokenize request; carries a single `prompt` string.
    prompt: str
2537
2538


2539
2540
class TokenizerInfoResponse(OpenAIBaseModel):
    """
    Response containing tokenizer configuration
    equivalent to tokenizer_config.json
    """

    # extra="allow" keeps any additional keys beyond the declared field.
    model_config = ConfigDict(extra="allow")
    tokenizer_class: str


2549
class LoadLoRAAdapterRequest(BaseModel):
    # Request to load a LoRA adapter, identified by a name and a path.
    lora_name: str
    lora_path: str


2554
class UnloadLoRAAdapterRequest(BaseModel):
    # Request to unload a LoRA adapter by name; `lora_int_id` is optional.
    lora_name: str
    lora_int_id: int | None = Field(default=None)
2557
2558
2559


## Protocols for Audio
2560
# Output formats accepted by the `response_format` field of the audio
# transcription and translation requests.
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
2561
2562
2563
2564


class TranscriptionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    language: str | None = None
    """The language of the input audio.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    ## TODO (varun) : Support if set to 0, certain thresholds are met !!

    timestamp_granularities: list[Literal["word", "segment"]] = Field(
        alias="timestamp_granularities[]", default=[]
    )
    """The timestamp granularities to populate for this transcription.

    `response_format` must be set `verbose_json` to use timestamp granularities.
    Either or both of these options are supported: `word`, or `segment`. Note:
    There is no additional latency for segment timestamps, but generating word
    timestamps incurs additional latency.
    """

    stream: bool | None = False
    """When set, it will enable output to be streamed in a similar fashion
    as the Chat Completion endpoint.
    """
    # --8<-- [start:transcription-extra-params]
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )
    # --8<-- [end:transcription-extra-params]

    to_language: str | None = None
    """The language of the output audio we transcribe to.

    Please note that this is not currently used by supported models at this
    time, but it is a placeholder for future use, matching translation api.
    """

    # --8<-- [start:transcription-sampling-params]
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    top_p: float | None = None
    """Enables nucleus (top-p) sampling, where tokens are selected from the
    smallest possible set whose cumulative probability exceeds `p`.
    """

    top_k: int | None = None
    """Limits sampling to the `k` most probable tokens at each step."""

    min_p: float | None = None
    """Filters out tokens with a probability lower than `min_p`, ensuring a
    minimum likelihood threshold during sampling.
    """

    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    frequency_penalty: float | None = 0.0
    """The frequency penalty to use for sampling."""

    repetition_penalty: float | None = None
    """The repetition penalty to use for sampling."""

    presence_penalty: float | None = 0.0
    """The presence penalty to use for sampling."""
    # --8<-- [end:transcription-sampling-params]

    # Default sampling parameters for transcription requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Build the `SamplingParams` for this transcription request.

        Each of `temperature`, `top_p`, `top_k`, `min_p` and
        `repetition_penalty` that is None on the request falls back first to
        `default_sampling_params` and then to `_DEFAULT_SAMPLING_PARAMS`.

        Args:
            default_max_tokens: Used unchanged as `max_tokens`; the request
                has no field to override it.
            default_sampling_params: Optional fallback values keyed by
                parameter name; None is treated as an empty dict.

        Returns:
            Sampling parameters whose output kind is DELTA when `stream` is
            truthy and FINAL_ONLY otherwise.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        # NOTE(review): `temperature` is declared as a non-optional float with
        # default 0.0, so this fallback looks unreachable through normal
        # validation - confirm whether the field was meant to be optional.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            presence_penalty=self.presence_penalty,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            extra_args=self.vllm_xargs,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_transcription_request(cls, data):
        """Check the raw input before field validation.

        Raises:
            HTTPException: 422 if `file` was supplied as a string.
            ValueError: If a stream option is set while `stream` is falsy.
        """
        # A "before" validator receives whatever was passed for validation,
        # which is not guaranteed to be a dict; leave other inputs to
        # pydantic's regular validation.
        if not isinstance(data, dict):
            return data

        if isinstance(data.get("file"), str):
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail="Expected 'file' to be a file-like object, not 'str'.",
            )

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data
2745
2746
2747


# Transcription response objects
2748
2749
2750
2751
2752
class TranscriptionUsageAudio(OpenAIBaseModel):
    # Usage record for a transcription: `type` is fixed to "duration" and
    # `seconds` is an integer count.
    type: Literal["duration"] = "duration"
    seconds: int


2753
2754
2755
class TranscriptionResponse(OpenAIBaseModel):
    text: str
    """The transcribed text."""
    # Usage record attached to the response (required).
    usage: TranscriptionUsageAudio
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
2786
2787
2788
2789
2790
2791
2792
2793
2794
2795
2796
2797
2798
2799
2800
2801
2802
2803
2804
2805
2806
2807


class TranscriptionWord(OpenAIBaseModel):
    # One word together with its start and end timestamps.
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranscriptionSegment(OpenAIBaseModel):
    # Per-segment details of a transcription.
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranscriptionResponseVerbose(OpenAIBaseModel):
    # Transcription response that can additionally carry per-segment and
    # per-word details.
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: list[TranscriptionSegment] | None = None
    """Segments of the transcribed text and their corresponding details."""

    words: list[TranscriptionWord] | None = None
    """Extracted words and their corresponding timestamps."""
2827
2828
2829
2830


class TranslationResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed translation chunk.
    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None
2833
2834
2835
2836
2837
2838
2839
2840


class TranslationStreamResponse(OpenAIBaseModel):
    # One streamed chunk of a translation response.
    # `id` defaults to a fresh "trsl-"-prefixed random id and `created` to the
    # current Unix time (whole seconds), both computed per instance.
    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
    object: Literal["translation.chunk"] = "translation.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranslationResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
2842
2843
2844
2845
2846
2847
2848
2849
2850
2851
2852
2853


class TranslationRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranslation

    file: UploadFile
    """
    The audio file object (not file name) to translate, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    # TODO support additional sampling parameters
    # --8<-- [start:translation-sampling-params]
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """
    # --8<-- [end:translation-sampling-params]

    # --8<-- [start:translation-extra-params]
    language: str | None = None
    """The language of the input audio we translate from.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy.
    """

    to_language: str | None = None
    """The language we translate the input audio to.

    Please note that this is not supported by all models, refer to the specific
    model documentation for more details.
    For instance, Whisper only supports `to_language=en`.
    """

    stream: bool | None = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
    Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False
    # --8<-- [end:translation-extra-params]

    # Default sampling parameters for translation requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Build the `SamplingParams` for this translation request.

        Args:
            default_max_tokens: Used unchanged as `max_tokens`; the request
                has no field to override it.
            default_sampling_params: Optional fallback values keyed by
                parameter name; None is treated as an empty dict.

        Returns:
            Sampling parameters whose output kind is DELTA when `stream` is
            truthy and FINAL_ONLY otherwise.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        # Default parameters
        # NOTE(review): `temperature` is declared as a non-optional float with
        # default 0.0, so this fallback looks unreachable through normal
        # validation - confirm whether the field was meant to be optional.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject stream options supplied without `stream` being truthy.

        Raises:
            ValueError: If `stream_include_usage` or
                `stream_continuous_usage_stats` is truthy in the raw input
                while `stream` is falsy or absent.
        """
        # A "before" validator receives whatever was passed for validation,
        # which is not guaranteed to be a dict; leave other inputs to
        # pydantic's regular validation.
        if not isinstance(data, dict):
            return data

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data


# Translation response objects
class TranslationResponse(OpenAIBaseModel):
    # Translation response carrying only the translated text.
    text: str
    """The translated text."""


class TranslationWord(OpenAIBaseModel):
    # One word together with its start and end timestamps.
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranslationSegment(OpenAIBaseModel):
    # Per-segment details of a translation. Declares the same fields, in the
    # same order, as TranscriptionSegment.
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranslationResponseVerbose(OpenAIBaseModel):
    # Translation response that can additionally carry per-segment and
    # per-word details.
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The translated text."""

    segments: list[TranslationSegment] | None = None
    """Segments of the translated text and their corresponding details."""

    words: list[TranslationWord] | None = None
    """Extracted words and their corresponding timestamps."""