protocol.py 104 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
6
import json
Zhuohan Li's avatar
Zhuohan Li committed
7
import time
8
from http import HTTPStatus
9
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar
Zhuohan Li's avatar
Zhuohan Li committed
10

11
import regex as re
12
import torch
13
from fastapi import HTTPException, UploadFile
14
from openai.types.chat.chat_completion_audio import (
15
16
17
    ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
18
19
20
21
22
from openai.types.responses import (
    ResponseCodeInterpreterCallCodeDeltaEvent,
    ResponseCodeInterpreterCallCodeDoneEvent,
    ResponseCodeInterpreterCallCompletedEvent,
    ResponseCodeInterpreterCallInProgressEvent,
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
    ResponseCodeInterpreterCallInterpretingEvent,
    ResponseContentPartAddedEvent,
    ResponseContentPartDoneEvent,
    ResponseFunctionToolCall,
    ResponseInputItemParam,
    ResponseOutputItem,
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponsePrompt,
    ResponseReasoningItem,
    ResponseReasoningTextDeltaEvent,
    ResponseReasoningTextDoneEvent,
    ResponseStatus,
    ResponseWebSearchCallCompletedEvent,
    ResponseWebSearchCallInProgressEvent,
    ResponseWebSearchCallSearchingEvent,
)
40
from openai.types.responses import (
41
42
43
    ResponseCompletedEvent as OpenAIResponseCompletedEvent,
)
from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent
44
from openai.types.responses import (
45
46
    ResponseInProgressEvent as OpenAIResponseInProgressEvent,
)
47
from openai.types.responses.response_reasoning_item import (
48
49
    Content as ResponseReasoningTextContent,
)
50
51
52
53
54

# Backward compatibility for OpenAI client versions
try:  # For older openai versions (< 1.100.0)
    from openai.types.responses import ResponseTextConfig
except ImportError:  # For newer openai versions (>= 1.100.0)
55
    from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig
56

57

58
from openai.types.responses.response import IncompleteDetails, ToolChoice
59
60
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
61
62
63
64
65
66
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    TypeAdapter,
    ValidationInfo,
67
    field_serializer,
68
69
70
    field_validator,
    model_validator,
)
Zhuohan Li's avatar
Zhuohan Li committed
71

72
from vllm import envs
73
74
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam
75
from vllm.logger import init_logger
76
from vllm.logprobs import Logprob
77
from vllm.pooling_params import PoolingParams
78
79
80
81
82
83
from vllm.sampling_params import (
    BeamSearchParams,
    RequestOutputKind,
    SamplingParams,
    StructuredOutputsParams,
)
84
85
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
86

87
88
89
90
91
92
93
94
95
96
97
98
# Maps embedding dtype names (strings) to the corresponding torch dtypes.
EMBED_DTYPE_TO_TORCH_DTYPE = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    # NOTE: it is unclear whether the CPUs of every platform support the fp8
    # data format. EMBED_DTYPE only uses the fp8 data representation (no fp8
    # computation), and this only happens on the CPU, so breakage on some
    # platforms is possible.
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
}

99
100
# Module-level logger.
logger = init_logger(__name__)

101
# Integer limits of torch.long; used to bound the `seed` request field.
_LONG_INFO = torch.iinfo(torch.long)
102

Zhuohan Li's avatar
Zhuohan Li committed
103

104
class OpenAIBaseModel(BaseModel):
    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Per-class cache of field names and aliases. It is always read from the
    # class's own namespace (see __log_extra_fields__) so that a subclass
    # never reuses the set that was computed for one of its parents.
    field_names: ClassVar[set[str] | None] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Validate `data` normally, then warn about keys that match no field.

        The wrapped `handler` performs the real validation; its result is
        always returned unchanged. The warning is only emitted when the raw
        input is a dict, because only then are there keys to compare.
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Look in cls.__dict__ rather than using attribute access: plain
        # `cls.field_names` would fall back to a parent's cached set and
        # report the subclass's own fields as "ignored".
        field_names = cls.__dict__.get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        unknown = data.keys() - field_names
        if unknown:
            logger.warning(
                "The following fields were present in the request but ignored: %s",
                unknown,
            )
        return result
134
135


136
class ErrorInfo(OpenAIBaseModel):
    # Details of a single error; embedded in ErrorResponse.error.
    message: str
    type: str
    # Optional; defaults to None when not provided.
    param: str | None = None
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
141
142


143
144
145
146
class ErrorResponse(OpenAIBaseModel):
    # Envelope that carries an ErrorInfo under the "error" key.
    error: ErrorInfo


147
class ModelPermission(OpenAIBaseModel):
    # Permission entry listed in ModelCard.permission.
    # `id` defaults to a fresh "modelperm-<uuid>" value per instance.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    # Defaults to the current Unix time, truncated to whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: str | None = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
160
161


162
class ModelCard(OpenAIBaseModel):
    # Description of one model; ModelList.data holds a list of these.
    id: str
    object: str = "model"
    # Defaults to the current Unix time, truncated to whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: str | None = None
    parent: str | None = None
    max_model_len: int | None = None
    # A new empty list is created for every instance.
    permission: list[ModelPermission] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
171
172


173
class ModelList(OpenAIBaseModel):
    # List container for ModelCard entries.
    object: str = "list"
    # A new empty list is created for every instance.
    data: list[ModelCard] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
176
177


178
class PromptTokenUsageInfo(OpenAIBaseModel):
    # Prompt-token breakdown; used as UsageInfo.prompt_tokens_details.
    cached_tokens: int | None = None
180
181


182
class UsageInfo(OpenAIBaseModel):
    # Token counters for a request; every counter defaults to 0.
    prompt_tokens: int = 0
    total_tokens: int = 0
    # Unlike the other counters, this one may also be None.
    completion_tokens: int | None = 0
    prompt_tokens_details: PromptTokenUsageInfo | None = None
Zhuohan Li's avatar
Zhuohan Li committed
187
188


189
190
class RequestResponseMetadata(BaseModel):
    # Derives from plain BaseModel (not OpenAIBaseModel), so the extra-field
    # handling defined on OpenAIBaseModel does not apply here.
    request_id: str
    # Unset (None) by default.
    final_usage_info: UsageInfo | None = None
192
193


194
195
class JsonSchemaResponseFormat(OpenAIBaseModel):
    # JSON-schema description; used as ResponseFormat.json_schema.
    name: str
    description: str | None = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    strict: bool | None = None
201
202


203
204
205
206
class StructuralTag(OpenAIBaseModel):
    # One entry of StructuralTagResponseFormat.structures.
    begin: str
    # schema is the field, but that causes conflicts with pydantic so
    # instead use structural_tag_schema with an alias
    structural_tag_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    end: str


class StructuralTagResponseFormat(OpenAIBaseModel):
    # `type` is fixed to the literal "structural_tag".
    type: Literal["structural_tag"]
    structures: list[StructuralTag]
    triggers: list[str]


217
class ResponseFormat(OpenAIBaseModel):
    # type must be "json_schema", "json_object", or "text"
    type: Literal["text", "json_object", "json_schema"]
    # Optional schema description; None by default.
    json_schema: JsonSchemaResponseFormat | None = None
221
222


223
# Either of the two response-format models; this is the type of
# ChatCompletionRequest.response_format.
AnyResponseFormat: TypeAlias = ResponseFormat | StructuralTagResponseFormat
224
225


226
class StreamOptions(OpenAIBaseModel):
    # Streaming options; used as ChatCompletionRequest.stream_options.
    include_usage: bool | None = True
    continuous_usage_stats: bool | None = False
229
230


231
232
class FunctionDefinition(OpenAIBaseModel):
    # Function description; used as ChatCompletionToolsParam.function.
    name: str
    description: str | None = None
    # Free-form dict; None when not provided.
    parameters: dict[str, Any] | None = None
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250


class ChatCompletionToolsParam(OpenAIBaseModel):
    # The only accepted tool type is "function", which is also the default.
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    # Holds a function name; used as ChatCompletionNamedToolChoiceParam.function.
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    # Tool choice that names a function; `type` is fixed to "function".
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


251
252
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
253
254
class LogitsProcessorConstructor(BaseModel):
    # Qualified name that get_logits_processors checks against the allowed
    # pattern and resolves with resolve_obj_by_qualname.
    qualname: str
    # Positional / keyword arguments passed when the resolved object is
    # called; None is treated as "no arguments" by get_logits_processors.
    args: list[Any] | None = None
    kwargs: dict[str, Any] | None = None

    # Unknown keys are rejected (needed so that `kwargs` can be a field).
    model_config = ConfigDict(extra="forbid")

260

261
# Each entry is either a qualified name (str) or a LogitsProcessorConstructor.
LogitsProcessors = list[str | LogitsProcessorConstructor]
262
263


264
def get_logits_processors(
    processors: LogitsProcessors | None, pattern: str | None
) -> list[Any] | None:
    """Resolve the requested logits processors into Python objects.

    Args:
        processors: Qualified names and/or `LogitsProcessorConstructor`
            entries taken from the request. ``None`` or an empty list means
            that no processors were requested.
        pattern: Regular expression that every qualified name must match
            (checked with ``re.match``). A falsy value means the server does
            not allow request-supplied logits processors.

    Returns:
        The resolved objects in request order, or ``None`` when no processors
        were requested. For `LogitsProcessorConstructor` entries the resolved
        object is called with the entry's ``args`` / ``kwargs`` and the call
        result is returned in its place.

    Raises:
        ValueError: If processors were requested but no pattern is
            configured, if a qualified name does not match the pattern, or
            if a qualified name cannot be resolved.
    """
    if not processors:
        return None

    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information."
        )

    resolved = []
    for processor in processors:
        qualname = processor if isinstance(processor, str) else processor.qualname
        # Check the allow-list before resolving, so names that are not
        # allowed are never imported.
        if not re.match(pattern, qualname):
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information."
            )
        try:
            logits_processor = resolve_obj_by_qualname(qualname)
        except Exception as e:
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {e}"
            ) from e
        if isinstance(processor, LogitsProcessorConstructor):
            # Constructor entries name a class/factory: call it with the
            # supplied arguments to obtain the actual processor.
            logits_processor = logits_processor(
                *(processor.args or []), **(processor.kwargs or {})
            )
        resolved.append(logits_processor)
    return resolved


298
299
300
# Item types accepted in the list form of `ResponsesRequest.input`.
ResponseInputOutputItem: TypeAlias = (
    ResponseInputItemParam | ResponseReasoningItem | ResponseFunctionToolCall
)
301
302


303
304
305
class ResponsesRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/responses/create
    background: bool | None = False
    include: (
        list[
            Literal[
                "code_interpreter_call.outputs",
                "computer_call_output.output.image_url",
                "file_search_call.results",
                "message.input_image.image_url",
                "message.output_text.logprobs",
                "reasoning.encrypted_content",
            ],
        ]
        | None
    ) = None
    input: str | list[ResponseInputOutputItem]
    instructions: str | None = None
    max_output_tokens: int | None = None
    max_tool_calls: int | None = None
    metadata: Metadata | None = None
    model: str | None = None
    parallel_tool_calls: bool | None = True
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
    store: bool | None = True
    stream: bool | None = False
    temperature: float | None = None
    text: ResponseTextConfig | None = None
    tool_choice: ToolChoice = "auto"
    tools: list[Tool] = Field(default_factory=list)
    top_logprobs: int | None = 0
    top_p: float | None = None
    truncation: Literal["auto", "disabled"] | None = "disabled"
    user: str | None = None

    # --8<-- [start:responses-extra-params]
    request_id: str = Field(
        default_factory=lambda: f"resp_{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."
        ),
    )

    enable_response_messages: bool = Field(
        default=False,
        description=(
            "Dictates whether or not to return messages as part of the "
            "response object. Currently only supported for non-streaming "
            "non-background and gpt-oss only. "
        ),
    )
    # --8<-- [end:responses-extra-params]

    # Last-resort values used by to_sampling_params when neither the request
    # nor the `default_sampling_params` argument provides one.
    _DEFAULT_SAMPLING_PARAMS = {
        "temperature": 1.0,
        "top_p": 1.0,
    }

    def to_sampling_params(
        self,
        default_max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Build the `SamplingParams` for this request.

        Args:
            default_max_tokens: Upper bound on generated tokens; a request
                `max_output_tokens` larger than this is clamped to it.
            default_sampling_params: Optional defaults consulted for
                `temperature`, `top_p` and `stop_token_ids` when the request
                leaves them unset.

        Returns:
            Sampling parameters assembled via `SamplingParams.from_optional`.

        Raises:
            NotImplementedError: If `text.format.type` is "json_object".
        """
        if self.max_output_tokens is None:
            max_tokens = default_max_tokens
        else:
            max_tokens = min(self.max_output_tokens, default_max_tokens)

        # Precedence: request value, then caller-provided defaults, then the
        # class-level fallbacks.
        default_sampling_params = default_sampling_params or {}
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        stop_token_ids = default_sampling_params.get("stop_token_ids")

        # Structured output
        structured_outputs = None
        if self.text is not None and self.text.format is not None:
            response_format = self.text.format
            if (
                response_format.type == "json_schema"
                and response_format.schema_ is not None
            ):
                structured_outputs = StructuredOutputsParams(
                    json=response_format.schema_
                )
            elif response_format.type == "json_object":
                raise NotImplementedError("json_object is not supported")

        # TODO: add more parameters
        return SamplingParams.from_optional(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            # Logprobs are only requested when the caller asked for them via
            # `include`.
            logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
            stop_token_ids=stop_token_ids,
            output_kind=(
                RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
            ),
            structured_outputs=structured_outputs,
        )

    def is_include_output_logprobs(self) -> bool:
        """Check if the request includes output logprobs."""
        include = self.include
        # A None `include` fails the isinstance check and yields False.
        return isinstance(include, list) and "message.output_text.logprobs" in include

    @model_validator(mode="before")
    @classmethod
    def validate_background(cls, data):
        """Reject `background` requests that explicitly disable `store`."""
        # Non-dict input has no keys to inspect; leave it to pydantic's own
        # validation instead of failing on `.get`.
        if not isinstance(data, dict):
            return data
        if not data.get("background"):
            return data
        if not data.get("store", True):
            raise ValueError("background can only be used when `store` is true")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt(cls, data):
        """Reject requests that set the `prompt` field."""
        if not isinstance(data, dict):
            return data
        if data.get("prompt") is not None:
            raise ValueError("prompt template is not supported")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate `cache_salt`: requires engine V1 and a non-empty string."""
        if not isinstance(data, dict):
            return data
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0."
                )
            if not isinstance(data["cache_salt"], str) or not data["cache_salt"]:
                raise ValueError(
                    "Parameter 'cache_salt' must be a non-empty string if provided."
                )
        return data

475

476
class ChatCompletionRequest(OpenAIBaseModel):
477
478
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
479
    messages: list[ChatCompletionMessageParam]
480
481
482
483
484
485
    model: str | None = None
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: bool | None = False
    top_logprobs: int | None = 0
    max_tokens: int | None = Field(
486
        default=None,
487
488
        deprecated="max_tokens is deprecated in favor of "
        "the max_completion_tokens field",
489
    )
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
    max_completion_tokens: int | None = None
    n: int | None = 1
    presence_penalty: float | None = 0.0
    response_format: AnyResponseFormat | None = None
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    temperature: float | None = None
    top_p: float | None = None
    tools: list[ChatCompletionToolsParam] | None = None
    tool_choice: (
        Literal["none"]
        | Literal["auto"]
        | Literal["required"]
        | ChatCompletionNamedToolChoiceParam
        | None
    ) = "none"
    reasoning_effort: Literal["low", "medium", "high"] | None = None
509
    include_reasoning: bool = True
510

511
    # NOTE this will be ignored by vLLM -- the model determines the behavior
512
513
    parallel_tool_calls: bool | None = False
    user: str | None = None
514

515
    # --8<-- [start:chat-completion-sampling-params]
516
    best_of: int | None = None
517
    use_beam_search: bool = False
518
519
520
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
521
    length_penalty: float = 1.0
522
    stop_token_ids: list[int] | None = []
523
524
525
526
527
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
528
529
530
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    prompt_logprobs: int | None = None
    allowed_token_ids: list[int] | None = None
531
    bad_words: list[str] = Field(default_factory=list)
532
    # --8<-- [end:chat-completion-sampling-params]
533

534
    # --8<-- [start:chat-completion-extra-params]
535
    echo: bool = Field(
536
537
538
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
539
540
            "if they belong to the same role."
        ),
541
    )
542
    add_generation_prompt: bool = Field(
543
        default=True,
544
545
546
547
548
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
549
    )
550
551
    continue_final_message: bool = Field(
        default=False,
552
553
554
555
556
557
558
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
559
    )
560
    add_special_tokens: bool = Field(
561
562
563
564
565
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
566
            "special tokens so this should be set to false (as is the "
567
568
            "default)."
        ),
569
    )
570
    documents: list[dict[str, str]] | None = Field(
571
        default=None,
572
573
574
575
576
577
578
        description=(
            "A list of dicts representing documents that will be accessible to "
            "the model if it is performing RAG (retrieval-augmented generation)."
            " If the template does not support RAG, this argument will have no "
            "effect. We recommend that each document should be a dict containing "
            '"title" and "text" keys.'
        ),
579
    )
580
    chat_template: str | None = Field(
581
582
583
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
584
585
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
586
587
            "does not define one."
        ),
588
    )
589
    chat_template_kwargs: dict[str, Any] | None = Field(
590
        default=None,
591
592
        description=(
            "Additional keyword args to pass to the template renderer. "
593
594
            "Will be accessible by the chat template."
        ),
595
    )
596
    mm_processor_kwargs: dict[str, Any] | None = Field(
597
598
599
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
600
    structured_outputs: StructuredOutputsParams | None = Field(
601
        default=None,
602
        description="Additional kwargs for structured outputs",
603
    )
604
    guided_json: str | dict | BaseModel | None = Field(
605
606
607
608
        default=None,
        description=(
            "`guided_json` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
609
610
            "Please pass `json` to `structured_outputs` instead."
        ),
611
    )
612
    guided_regex: str | None = Field(
613
614
615
616
        default=None,
        description=(
            "`guided_regex` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
617
618
            "Please pass `regex` to `structured_outputs` instead."
        ),
619
    )
620
    guided_choice: list[str] | None = Field(
621
622
623
624
        default=None,
        description=(
            "`guided_choice` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
625
626
            "Please pass `choice` to `structured_outputs` instead."
        ),
627
    )
628
    guided_grammar: str | None = Field(
629
630
631
632
        default=None,
        description=(
            "`guided_grammar` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
633
634
            "Please pass `grammar` to `structured_outputs` instead."
        ),
635
    )
636
    structural_tag: str | None = Field(
637
638
639
640
        default=None,
        description=(
            "`structural_tag` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
641
642
            "Please pass `structural_tag` to `structured_outputs` instead."
        ),
643
    )
644
    guided_decoding_backend: str | None = Field(
645
646
647
648
        default=None,
        description=(
            "`guided_decoding_backend` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
649
650
            "Please remove it from your request."
        ),
651
    )
652
    guided_whitespace_pattern: str | None = Field(
653
654
655
656
657
658
659
        default=None,
        description=(
            "`guided_whitespace_pattern` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `whitespace_pattern` to `structured_outputs` instead."
        ),
    )
660
661
662
663
664
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
665
666
            "if the served model does not use priority scheduling."
        ),
667
    )
668
669
670
671
672
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
673
674
            "through out the inference process and return in response."
        ),
675
    )
676
    logits_processors: LogitsProcessors | None = Field(
677
678
679
680
681
682
683
684
685
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
686
687
688
            "{'param': 'value'}}."
        ),
    )
689
    return_tokens_as_token_ids: bool | None = Field(
690
691
692
693
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
694
695
696
            "that are not JSON-encodable can be identified."
        ),
    )
697
    return_token_ids: bool | None = Field(
698
699
700
701
702
703
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
704
705
706
            "need to map generated text back to input tokens."
        ),
    )
707
    cache_salt: str | None = Field(
708
709
710
711
712
713
714
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
715
716
717
            "to 256 bit). Not supported by vLLM engine V0."
        ),
    )
718
    kv_transfer_params: dict[str, Any] | None = Field(
Robert Shaw's avatar
Robert Shaw committed
719
        default=None,
720
721
        description="KVTransfer parameters used for disaggregated serving.",
    )
722

723
    vllm_xargs: dict[str, str | int | float] | None = Field(
724
        default=None,
725
726
727
728
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
729
730
    )

731
    # --8<-- [end:chat-completion-extra-params]
Zhuohan Li's avatar
Zhuohan Li committed
732

733
734
735
736
737
    # Default sampling parameters for chat completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
738
        "top_k": 0,
739
740
741
742
        "min_p": 0.0,
    }

    def to_beam_search_params(
743
744
        self, max_tokens: int, default_sampling_params: dict
    ) -> BeamSearchParams:
745
        n = self.n if self.n is not None else 1
746
747
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
748
749
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
750
751
752
753
754
755

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
756
            length_penalty=self.length_penalty,
757
758
            include_stop_str_in_output=self.include_stop_str_in_output,
        )
759

760
    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict,
    ) -> SamplingParams:
        """Translate this chat request into engine `SamplingParams`.

        Sampling knobs left as ``None`` on the request fall back to
        ``default_sampling_params`` and then to ``_DEFAULT_SAMPLING_PARAMS``.
        Deprecated ``guided_*`` fields, ``response_format`` and the schema
        implied by the tool choice are folded into ``self.structured_outputs``,
        which is updated in place on the request.

        Args:
            max_tokens: Generation budget forwarded unchanged.
            logits_processor_pattern: Passed to `get_logits_processors`
                together with the requested logits processors.
            default_sampling_params: Defaults consulted for knobs the request
                leaves unset.

        Returns:
            The object built by `SamplingParams.from_optional`.

        Raises:
            ValueError: Propagated from `_get_json_schema_from_tool` (unknown
                named tool, or conflicting tool ``$defs``).
        """

        def resolve(name: str) -> Any:
            # Request value wins; otherwise supplied default, then class default.
            value = getattr(self, name)
            if value is None:
                value = default_sampling_params.get(
                    name, self._DEFAULT_SAMPLING_PARAMS[name]
                )
            return value

        repetition_penalty = resolve("repetition_penalty")
        temperature = resolve("temperature")
        top_p = resolve("top_p")
        top_k = resolve("top_k")
        min_p = resolve("min_p")

        # With `echo`, prompt logprobs default to the requested top_logprobs.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        # Forward deprecated guided_* parameters to structured_outputs
        if self.structured_outputs is None:
            kwargs = dict[str, Any](
                json=self.guided_json,
                regex=self.guided_regex,
                choice=self.guided_choice,
                grammar=self.guided_grammar,
                whitespace_pattern=self.guided_whitespace_pattern,
                structural_tag=self.structural_tag,
            )
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            if kwargs:
                self.structured_outputs = StructuredOutputsParams(**kwargs)

        response_format = self.response_format
        json_schema_from_tool = self._get_json_schema_from_tool()
        if response_format is not None or json_schema_from_tool is not None:
            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            # Set structured output params for response format
            if response_format is not None:
                if response_format.type == "json_object":
                    self.structured_outputs.json_object = True
                elif response_format.type == "json_schema":
                    json_schema = response_format.json_schema
                    assert json_schema is not None
                    self.structured_outputs.json = json_schema.json_schema
                elif response_format.type == "structural_tag":
                    structural_tag = response_format
                    assert structural_tag is not None and isinstance(
                        structural_tag, StructuralTagResponseFormat
                    )
                    s_tag_obj = structural_tag.model_dump(by_alias=True)
                    self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

            # Set structured output params for tool calling
            # (applied last, so it replaces any `json` set above)
            if json_schema_from_tool is not None:
                self.structured_outputs.json = json_schema_from_tool

        # Copy vllm_xargs: inserting kv_transfer_params into the original dict
        # would mutate the request's own `vllm_xargs` field.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            # `logprobs` is a flag here; the count comes from `top_logprobs`.
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            bad_words=self.bad_words,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )
873

874
    def _get_json_schema_from_tool(self) -> str | dict | None:
        """Return the JSON schema that the tool choice imposes on the output.

        The request's ``tools`` are only read, never modified, so repeated
        calls return the same result.

        Returns:
            ``None`` when ``tool_choice`` is ``"none"``, ``tools`` is ``None``,
            or the tool choice is neither a named tool nor ``"required"``;
            the chosen function's ``parameters`` for a named tool choice; for
            ``"required"``, an array schema (at least one item) whose items
            must match one of the given tools, with all tool ``$defs`` hoisted
            to the top level.

        Raises:
            ValueError: If the named tool is not present in ``tools``, or two
                tools define the same ``$defs`` name with different schemas.
        """
        # user has chosen to not use any tool
        if self.tool_choice == "none" or self.tools is None:
            return None

        # user has chosen to use a named tool
        if isinstance(self.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_name = self.tool_choice.function.name
            tools = {tool.function.name: tool.function for tool in self.tools}
            if tool_name not in tools:
                raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
            tool = tools[tool_name]
            return tool.parameters

        if self.tool_choice == "required":
            # Pydantic schema generation cannot be used since the JSON schema
            # has to be constructed for a specific instantiation of a tool list
            # so that parameters of a function are correctly generated
            # based on the chosen function name
            def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
                parameters = tool.function.parameters
                if parameters:
                    # "$defs" are hoisted to the top level of the combined
                    # schema; drop them from a copy instead of popping them
                    # from the request's own tool definition.
                    parameters = {
                        key: value
                        for key, value in parameters.items()
                        if key != "$defs"
                    }
                else:
                    # parameters are always generated as '{}' in the final
                    # output if they are missing from the request
                    # (i.e. are None or '{}') so the schema is
                    # updated to produce an empty object in that case
                    parameters = {"type": "object", "properties": {}}
                return {
                    "properties": {
                        "name": {"type": "string", "enum": [tool.function.name]},
                        "parameters": parameters,
                    },
                    "required": ["name", "parameters"],
                }

            def get_tool_schema_defs(tools: list[ChatCompletionToolsParam]) -> dict:
                all_defs = dict[str, dict[str, Any]]()
                for tool in tools:
                    parameters = tool.function.parameters
                    if parameters is None:
                        continue
                    # Read-only lookup: the tool definition is left untouched.
                    for def_name, def_schema in parameters.get("$defs", {}).items():
                        if def_name in all_defs and all_defs[def_name] != def_schema:
                            raise ValueError(
                                f"Tool definition '{def_name}' has "
                                "multiple schemas, which is not "
                                "supported."
                            )
                        all_defs[def_name] = def_schema
                return all_defs

            json_schema = {
                "type": "array",
                "minItems": 1,
                "items": {
                    "type": "object",
                    "anyOf": [get_tool_schema(tool) for tool in self.tools],
                },
            }
            json_schema_defs = get_tool_schema_defs(self.tools)
            if json_schema_defs:
                json_schema["$defs"] = json_schema_defs
            return json_schema

        return None
939

940
    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` on a request that is not streaming.

        Raises:
            ValueError: If ``stream_options`` is set (truthy) while ``stream``
                is missing or falsy.
        """
        # Nothing to validate unless stream options were actually supplied.
        if not data.get("stream_options"):
            return data
        if data.get("stream"):
            return data
        raise ValueError("Stream options can only be defined when `stream=True`.")

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `top_logprobs` on the raw request.

        Both fields accept non-negative values or the special value ``-1``.

        Raises:
            ValueError: If ``prompt_logprobs`` is requested together with
                streaming, is below ``-1``, or is ``-1`` without engine V1;
                or if ``top_logprobs`` is below ``-1`` or is requested
                without ``logprobs`` being set.
        """
        prompt_logprobs = data.get("prompt_logprobs")
        if prompt_logprobs is not None:
            if data.get("stream"):
                # 0 is tolerated while streaming; anything else is rejected.
                if prompt_logprobs > 0 or prompt_logprobs == -1:
                    raise ValueError(
                        "`prompt_logprobs` are not available when `stream=True`."
                    )

            if prompt_logprobs != -1 and prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value or -1.")
            if prompt_logprobs == -1 and not envs.VLLM_USE_V1:
                raise ValueError(
                    "`prompt_logprobs=-1` is only supported with vLLM engine V1."
                )

        top_logprobs = data.get("top_logprobs")
        if top_logprobs is not None:
            if top_logprobs != -1 and top_logprobs < 0:
                raise ValueError("`top_logprobs` must be a positive value or -1.")

            if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"):
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )

        return data
973

974
975
    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Validate the constraints given in `structured_outputs`.

        Raises:
            ValueError: If more than one of ``json`` / ``regex`` / ``choice``
                is set, or if one of them is combined with a ``tool_choice``
                other than ``"none"``, ``"auto"`` or ``"required"`` (i.e. a
                named tool). ``data`` itself is re-raised if it already is a
                ``ValueError``.
        """
        if isinstance(data, ValueError):
            raise data

        structured_outputs_kwargs = data.get("structured_outputs")
        if structured_outputs_kwargs is None:
            return data

        count = sum(
            structured_outputs_kwargs.get(k) is not None
            for k in ("json", "regex", "choice")
        )
        # you can only use one kind of constraints for structured outputs
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )

        # you can only either use structured outputs or tools, not both.
        # This must trigger for a single constraint as well: the previous
        # `count > 1` guard could never be reached after the check above.
        # An explicit `tool_choice=None` counts as "no tool selected".
        tool_choice = data.get("tool_choice")
        uses_named_tool = tool_choice is not None and tool_choice not in (
            "none",
            "auto",
            "required",
        )
        if count > 0 and uses_named_tool:
            raise ValueError(
                "You can only either use constraints for structured outputs "
                "or tools, not both."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        """Normalize and validate `tool_choice` against `tools`.

        ``data`` is updated in place: a missing ``tool_choice`` becomes
        ``"auto"`` when tools are given, and ``"required"`` with an empty
        tool list is rewritten to ``"none"`` with ``tools`` removed.

        Raises:
            ValueError: If ``tool_choice`` is set without ``tools``, has an
                unsupported value, is a malformed named-tool object, or names
                a tool that is not present in ``tools``.
        """

        def declared_name(tool):
            # Tolerate malformed entries (missing keys, non-dict objects):
            # they simply never match, instead of raising KeyError/TypeError
            # out of the validator.
            if isinstance(tool, dict):
                function = tool.get("function")
            else:
                function = getattr(tool, "function", None)
            if isinstance(function, dict):
                return function.get("name")
            return getattr(function, "name", None)

        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        tool_choice = data.get("tool_choice")
        # "none" needs no validation; an absent/None choice has nothing to check
        if tool_choice is None or tool_choice == "none":
            return data

        # ensure that if "tool choice" is specified, tools are present
        tools = data.get("tools")
        if tools is None:
            raise ValueError("When using `tool_choice`, `tools` must be set.")

        # make sure that tool choice is either a named tool
        # OR that it's set to "auto" or "required"
        if tool_choice not in ["auto", "required"] and not isinstance(
            tool_choice, dict
        ):
            raise ValueError(
                f"Invalid value for `tool_choice`: {tool_choice}! "
                'Only named tools, "none", "auto" or "required" '
                "are supported."
            )

        # if tool_choice is "required" but the "tools" list is empty,
        # override the data to behave like "none" to align with
        # OpenAI’s behavior.
        if tool_choice == "required" and isinstance(tools, list) and len(tools) == 0:
            data["tool_choice"] = "none"
            del data["tools"]
            return data

        # ensure that if "tool_choice" is specified as an object,
        # it matches a valid tool
        if isinstance(tool_choice, dict):
            correct_usage_message = (
                'Correct usage: `{"type": "function",'
                ' "function": {"name": "my_function"}}`'
            )
            function = tool_choice.get("function")
            if not isinstance(function, dict):
                raise ValueError(
                    f"Invalid value for `function`: `{function}` in "
                    f"`tool_choice`! {correct_usage_message}"
                )
            if "name" not in function:
                raise ValueError(
                    f"Expected field `name` in `function` in "
                    f"`tool_choice`! {correct_usage_message}"
                )
            function_name = function["name"]
            if not isinstance(function_name, str) or len(function_name) == 0:
                raise ValueError(
                    f"Invalid `name` in `function`: `{function_name}`"
                    f" in `tool_choice`! {correct_usage_message}"
                )
            if not any(declared_name(tool) == function_name for tool in tools):
                raise ValueError(
                    "The tool specified in `tool_choice` does not match any"
                    " of the specified `tools`"
                )
        return data

1083
1084
1085
    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests that set two mutually exclusive prompt flags.

        Raises:
            ValueError: If both ``continue_final_message`` and
                ``add_generation_prompt`` are truthy.
        """
        conflicting = data.get("continue_final_message") and data.get(
            "add_generation_prompt"
        )
        if conflicting:
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

1093
1094
1095
1096
1097
1098
1099
    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate the optional `cache_salt` request parameter.

        Raises:
            ValueError: If ``cache_salt`` is given while
                ``envs.VLLM_USE_V1`` is falsy, or if it is not a non-empty
                string.
        """
        cache_salt = data.get("cache_salt")
        if cache_salt is None:
            # Parameter not supplied: nothing to check.
            return data

        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Parameter 'cache_salt' is not supported with "
                "this instance of vLLM, which uses engine V0."
            )
        if not (isinstance(cache_salt, str) and cache_salt):
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data

Zhuohan Li's avatar
Zhuohan Li committed
1108

1109
class CompletionRequest(OpenAIBaseModel):
    # Request body of the text completion endpoint. The first group of fields
    # mirrors the OpenAI API; the groups delimited by the `--8<--` markers
    # are extra sampling and request parameters.
    #
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str | None = None
    prompt: list[int] | list[list[int]] | str | list[str] | None = None
    best_of: int | None = None
    echo: bool | None = False
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: int | None = None
    max_tokens: int | None = 16
    n: int = 1
    presence_penalty: float | None = 0.0
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    suffix: str | None = None
    # temperature / top_p (and top_k, min_p, repetition_penalty below) default
    # to None so that to_sampling_params can apply fallback defaults.
    temperature: float | None = None
    top_p: float | None = None
    user: str | None = None

    # --8<-- [start:completion-sampling-params]
    use_beam_search: bool = False
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
    length_penalty: float = 1.0
    stop_token_ids: list[int] | None = []
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    allowed_token_ids: list[int] | None = None
    prompt_logprobs: int | None = None
    # --8<-- [end:completion-sampling-params]

    # --8<-- [start:completion-extra-params]
    prompt_embeds: bytes | list[bytes] | None = None
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    response_format: AnyResponseFormat | None = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
            ", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
        ),
    )
    structured_outputs: StructuredOutputsParams | None = Field(
        default=None,
        description="Additional kwargs for structured outputs",
    )
    guided_json: str | dict | BaseModel | None = Field(
        default=None,
        description=(
            "`guided_json` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `json` to `structured_outputs` instead."
        ),
    )
    guided_regex: str | None = Field(
        default=None,
        description=(
            "`guided_regex` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `regex` to `structured_outputs` instead."
        ),
    )
    guided_choice: list[str] | None = Field(
        default=None,
        description=(
            "`guided_choice` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `choice` to `structured_outputs` instead."
        ),
    )
    guided_grammar: str | None = Field(
        default=None,
        description=(
            "`guided_grammar` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `grammar` to `structured_outputs` instead."
        ),
    )
    structural_tag: str | None = Field(
        default=None,
        description=("If specified, the output will follow the structural tag schema."),
    )
    guided_decoding_backend: str | None = Field(
        default=None,
        description=(
            "`guided_decoding_backend` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please remove it from your request."
        ),
    )
    guided_whitespace_pattern: str | None = Field(
        default=None,
        description=(
            "`guided_whitespace_pattern` is deprecated. "
            "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. "
            "Please pass `whitespace_pattern` to `structured_outputs` instead."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    logits_processors: LogitsProcessors | None = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."
        ),
    )

    return_tokens_as_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            "as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."
        ),
    )
    return_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."
        ),
    )

    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."
        ),
    )

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )

    # --8<-- [end:completion-extra-params]

    # Default sampling parameters for completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self,
        max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> BeamSearchParams:
        """Build the `BeamSearchParams` for this completion request.

        Args:
            max_tokens: Generation budget forwarded unchanged.
            default_sampling_params: Optional defaults consulted for
                ``temperature`` when the request leaves it unset; ``None`` is
                treated as an empty mapping.

        Returns:
            Beam search parameters whose beam width is the request's ``n``
            (1 when ``n`` is ``None``).
        """
        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        if (temperature := self.temperature) is None:
            # Fall back to the class-level default, as to_sampling_params
            # does, rather than repeating the literal value here.
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )
1326

1327
    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Translate this completion request into engine `SamplingParams`.

        Sampling knobs left as ``None`` on the request fall back to
        ``default_sampling_params`` and then to ``_DEFAULT_SAMPLING_PARAMS``.
        ``response_format`` is mapped onto ``guided_json`` /
        ``structural_tag`` on the request, and, when ``structured_outputs``
        is not already set, the deprecated ``guided_*`` fields and
        ``structural_tag`` are forwarded into a new `StructuredOutputsParams`
        stored on the request.

        Args:
            max_tokens: Generation budget; replaced by 1 when the request
                asks for ``echo`` with ``max_tokens == 0``.
            logits_processor_pattern: Passed to `get_logits_processors`
                together with the requested logits processors.
            default_sampling_params: Optional defaults consulted for knobs
                the request leaves unset; ``None`` is treated as empty.

        Returns:
            The object built by `SamplingParams.from_optional`.
        """
        if default_sampling_params is None:
            default_sampling_params = {}

        def resolve(name: str) -> Any:
            # Request value wins; otherwise supplied default, then class default.
            value = getattr(self, name)
            if value is None:
                value = default_sampling_params.get(
                    name, self._DEFAULT_SAMPLING_PARAMS[name]
                )
            return value

        repetition_penalty = resolve("repetition_penalty")
        temperature = resolve("temperature")
        top_p = resolve("top_p")
        top_k = resolve("top_k")
        min_p = resolve("min_p")

        # With `echo`, prompt logprobs default to the requested logprobs.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        echo_without_generation = self.echo and self.max_tokens == 0

        guided_json_object = None
        if self.response_format is not None:
            if self.response_format.type == "json_object":
                guided_json_object = True
            elif self.response_format.type == "json_schema":
                json_schema = self.response_format.json_schema
                assert json_schema is not None
                self.guided_json = json_schema.json_schema
            elif self.response_format.type == "structural_tag":
                structural_tag = self.response_format
                assert structural_tag is not None and isinstance(
                    structural_tag, StructuralTagResponseFormat
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structural_tag = json.dumps(s_tag_obj)

        # Forward deprecated guided_* parameters to structured_outputs
        if self.structured_outputs is None:
            kwargs = dict[str, Any](
                json=self.guided_json,
                json_object=guided_json_object,
                regex=self.guided_regex,
                choice=self.guided_choice,
                grammar=self.guided_grammar,
                whitespace_pattern=self.guided_whitespace_pattern,
                # Previously omitted, which silently dropped the
                # `structural_tag` field and a structural_tag response_format.
                structural_tag=self.structural_tag,
            )
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            if kwargs:
                self.structured_outputs = StructuredOutputsParams(**kwargs)

        # Copy vllm_xargs: inserting kv_transfer_params into the original dict
        # would mutate the request's own `vllm_xargs` field.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )
1432

1433
1434
    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Ensure `structured_outputs` carries at most one constraint kind.

        Raises:
            ValueError: If more than one of ``json`` / ``regex`` / ``choice``
                is set (not ``None``) in ``structured_outputs``.
        """
        structured_outputs_kwargs = data.get("structured_outputs", None)
        if structured_outputs_kwargs is None:
            return data

        # Tally how many of the mutually exclusive constraints were supplied.
        count = 0
        for constraint in ("json", "regex", "choice"):
            if structured_outputs_kwargs.get(constraint) is not None:
                count += 1

        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )
        return data

1451
1452
1453
    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `logprobs` on the raw request.

        Raises:
            ValueError: If ``prompt_logprobs`` is requested together with
                streaming, is below ``-1``, or is ``-1`` without engine V1;
                or if ``logprobs`` is negative.
        """
        prompt_logprobs = data.get("prompt_logprobs")
        if prompt_logprobs is not None:
            if data.get("stream"):
                # 0 is tolerated while streaming; anything else is rejected.
                if prompt_logprobs > 0 or prompt_logprobs == -1:
                    raise ValueError(
                        "`prompt_logprobs` are not available when `stream=True`."
                    )

            if prompt_logprobs != -1 and prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value or -1.")
            if prompt_logprobs == -1 and not envs.VLLM_USE_V1:
                raise ValueError(
                    "`prompt_logprobs=-1` is only supported with vLLM engine V1."
                )

        logprobs = data.get("logprobs")
        if logprobs is not None:
            if logprobs < 0:
                raise ValueError("`logprobs` must be a positive value.")

        return data

1471
1472
1473
1474
    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` on a request that is not streaming.

        Raises:
            ValueError: If ``stream_options`` is set (truthy) while ``stream``
                is missing or falsy.
        """
        # Nothing to validate unless stream options were actually supplied.
        if not data.get("stream_options"):
            return data
        if data.get("stream"):
            return data
        raise ValueError("Stream options can only be defined when `stream=True`.")

1479
1480
1481
    @model_validator(mode="before")
    @classmethod
    def validate_prompt_and_prompt_embeds(cls, data):
        """Require a usable `prompt` or `prompt_embeds` on the request.

        ``prompt`` counts as empty when it is ``None`` or the empty string;
        ``prompt_embeds`` counts as empty when it is ``None`` or an empty
        list.

        Raises:
            ValueError: If both ``prompt`` and ``prompt_embeds`` are empty.
        """
        prompt = data.get("prompt")
        prompt_embeds = data.get("prompt_embeds")

        if prompt is not None and not (isinstance(prompt, str) and prompt == ""):
            # A non-empty prompt is enough on its own.
            return data

        if prompt_embeds is not None and not (
            isinstance(prompt_embeds, list) and len(prompt_embeds) == 0
        ):
            return data

        raise ValueError(
            "Either prompt or prompt_embeds must be provided and non-empty."
        )

1497
1498
1499
1500
1501
1502
1503
    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate the optional `cache_salt` request parameter.

        Raises:
            ValueError: If ``cache_salt`` is given while
                ``envs.VLLM_USE_V1`` is falsy, or if it is not a non-empty
                string.
        """
        cache_salt = data.get("cache_salt")
        if cache_salt is None:
            # Parameter not supplied: nothing to check.
            return data

        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Parameter 'cache_salt' is not supported with "
                "this instance of vLLM, which uses engine V0."
            )
        if not (isinstance(cache_salt, str) and cache_salt):
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data

Zhuohan Li's avatar
Zhuohan Li committed
1512

1513
class EmbeddingCompletionRequest(OpenAIBaseModel):
    """Embedding request whose `input` is given as strings or integer lists.

    Declares the standard request fields followed by the extra parameters
    delimited by the `embedding-extra-params` snippet markers.
    """

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str | None = None
    # A single item or a batch; integer lists are presumably token ids.
    input: list[int] | list[list[int]] | str | list[str]
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: int | None = None
    user: str | None = None
    # Optional; when provided the value must be >= -1 (`Field(ge=-1)`).
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:embedding-extra-params]
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    # A fresh random id is generated per instance when the caller omits one.
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    normalize: bool | None = Field(
        default=None,
        description="Whether to normalize the embeddings outputs. Default is True.",
    )
    embed_dtype: str = Field(
        default="float32",
        description=(
            "What dtype to use for base64 encoding. Default to using "
            "float32 for base64 encoding to match the OpenAI python client behavior."
        ),
    )
    # --8<-- [end:embedding-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from the truncation, dimensions and
        normalize settings of this request."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            dimensions=self.dimensions,
            normalize=self.normalize,
        )
1566
1567


1568
class EmbeddingChatRequest(OpenAIBaseModel):
    """Embedding request whose input is a list of chat messages.

    Declares the standard request fields followed by the chat-template and
    scheduling extras delimited by the `chat-embedding-extra-params`
    snippet markers.
    """

    model: str | None = None
    messages: list[ChatCompletionMessageParam]

    encoding_format: Literal["float", "base64"] = "float"
    dimensions: int | None = None
    user: str | None = None
    # Optional; when provided the value must be >= -1 (`Field(ge=-1)`).
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:chat-embedding-extra-params]
    add_generation_prompt: bool = Field(
        default=False,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )

    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    # A fresh random id is generated per instance when the caller omits one.
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    normalize: bool | None = Field(
        default=None,
        description="Whether to normalize the embeddings outputs. Default is True.",
    )
    embed_dtype: str = Field(
        default="float32",
        description=(
            "Which dtype to use for base64 encoding. Defaults to float32 "
            "to match OpenAI API."
        ),
    )
    # --8<-- [end:chat-embedding-extra-params]

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject raw input that sets both `continue_final_message` and
        `add_generation_prompt` to truthy values.

        Raises:
            ValueError: both flags are truthy.
        """
        # NOTE(review): `continue_final_message` is not declared as a field in
        # this class; the check reads it from the raw input only. Confirm this
        # is intended.
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

    def to_pooling_params(self):
        """Build `PoolingParams` from the truncation, dimensions and
        normalize settings of this request."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            dimensions=self.dimensions,
            normalize=self.normalize,
        )
1662
1663


1664
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
1665

1666
1667
PoolingCompletionRequest = EmbeddingCompletionRequest
PoolingChatRequest = EmbeddingChatRequest
1668
1669
1670
1671
1672

T = TypeVar("T")


class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
    # Generic request wrapper: `data` has the caller-chosen type `T`; the other
    # fields carry model name, priority, activation flag and embed dtype.
    model: str | None = None

    priority: int = Field(default=0)
    """
    The priority of the request (lower means earlier handling;
    default: 0). Any priority other than 0 will raise an error
    if the served model does not use priority scheduling.
    """
    data: T
    """
    When using plugins IOProcessor plugins, the actual input is processed
    by the plugin itself. Hence, we use a generic type for the request data
    """
    activation: bool = False

    embed_dtype: str = Field(
        default="float32",
        description=(
            "What dtype to use for base64 encoding. Default to using "
            "float32 for base64 encoding to match the OpenAI python client behavior."
        ),
    )

    def to_pooling_params(self):
        # Always uses the "token_classify" task; only `activation` is taken
        # from the request.
        return PoolingParams(task="token_classify", activation=self.activation)
1698
1699
1700


class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
    # Generic response wrapper: `data` has the caller-chosen type `T`.
    request_id: str | None = None
    """
    The request_id associated with this response
    """
    # Defaults to the current Unix time in whole seconds at construction.
    created_at: int = Field(default_factory=lambda: int(time.time()))

    data: T
    """
    When using plugins IOProcessor plugins, the actual output is generated
    by the plugin itself. Hence, we use a generic type for the response data
    """


1714
1715
1716
# Any request model accepted for pooling: completion-style, chat-style, or the
# generic IOProcessor request.
PoolingRequest: TypeAlias = (
    PoolingCompletionRequest | PoolingChatRequest | IOProcessorRequest
)
1717

1718

1719
class ScoreRequest(OpenAIBaseModel):
    """Request model carrying two inputs (`text_1`, `text_2`) to be scored.

    Each input may be a string, a list of strings, or a
    `ScoreMultiModalParam`.
    """

    model: str | None = None
    text_1: list[str] | str | ScoreMultiModalParam
    text_2: list[str] | str | ScoreMultiModalParam
    # Optional; when provided the value must be >= -1 (`Field(ge=-1)`).
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:score-extra-params]

    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: bool | None = None

    # --8<-- [end:score-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from the truncation and activation settings."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation,
        )
1750
1751


1752
class RerankRequest(OpenAIBaseModel):
    """Request model carrying a `query` and the `documents` to rerank."""

    model: str | None = None
    query: str | ScoreMultiModalParam
    documents: list[str] | ScoreMultiModalParam
    # Defaults to 0 when omitted.
    top_n: int = Field(default_factory=lambda: 0)
    # Optional; when provided the value must be >= -1 (`Field(ge=-1)`).
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None

    # --8<-- [start:rerank-extra-params]

    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: bool | None = None

    # --8<-- [end:rerank-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from the truncation and activation settings."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation,
        )
1784
1785
1786


class RerankDocument(BaseModel):
    """A reranked document: optional text and/or optional multi-modal part."""

    text: str | None = None
    multi_modal: ScoreContentPartParam | None = None
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804


class RerankResult(BaseModel):
    """One rerank result: an index, the document and its relevance score."""

    index: int
    document: RerankDocument
    relevance_score: float


class RerankUsage(BaseModel):
    """Usage information of a rerank response (total token count only)."""

    total_tokens: int


class RerankResponse(OpenAIBaseModel):
    """Rerank response: id, model name, usage and the list of results."""

    id: str
    model: str
    usage: RerankUsage
    results: list[RerankResult]
1806
1807


1808
class CompletionLogProbs(OpenAIBaseModel):
    """Log-probability lists for a completion.

    Every field defaults to a fresh empty list.
    """

    text_offset: list[int] = Field(default_factory=list)
    token_logprobs: list[float | None] = Field(default_factory=list)
    tokens: list[str] = Field(default_factory=list)
    top_logprobs: list[dict[str, float] | None] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
1813
1814


1815
class CompletionResponseChoice(OpenAIBaseModel):
    """A single choice of a non-streaming completion response."""

    index: int
    text: str
    logprobs: CompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    token_ids: list[int] | None = None  # For response
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    prompt_token_ids: list[int] | None = None  # For prompt
Zhuohan Li's avatar
Zhuohan Li committed
1831
1832


1833
class CompletionResponse(OpenAIBaseModel):
    """Non-streaming completion response (`object` is "text_completion")."""

    # Defaults to "cmpl-" followed by a random uuid.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: Literal["text_completion"] = "text_completion"
    # Defaults to the current Unix time in whole seconds at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None, description="KVTransfer parameters."
    )
Zhuohan Li's avatar
Zhuohan Li committed
1847
1848


1849
class CompletionResponseStreamChoice(OpenAIBaseModel):
    """A single choice inside a streamed completion chunk."""

    index: int
    text: str
    logprobs: CompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    # not part of the OpenAI spec but for tracing the tokens
    # prompt tokens is put into choice to align with CompletionResponseChoice
    prompt_token_ids: list[int] | None = None
    token_ids: list[int] | None = None
Zhuohan Li's avatar
Zhuohan Li committed
1866
1867


1868
class CompletionStreamResponse(OpenAIBaseModel):
    """One streamed completion chunk; `usage` is optional."""

    # Defaults to "cmpl-" followed by a random uuid.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Defaults to the current Unix time in whole seconds at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
1875
1876


1877
class EmbeddingResponseData(OpenAIBaseModel):
    """One embedding entry; the embedding is a float list or a string."""

    index: int
    object: str = "embedding"
    # A string is presumably the base64-encoded form - TODO confirm.
    embedding: list[float] | str
1881
1882


1883
class EmbeddingResponse(OpenAIBaseModel):
    """Embedding response: a list of `EmbeddingResponseData` plus usage."""

    # Defaults to "embd-" followed by a random uuid.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    # Defaults to the current Unix time in whole seconds at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[EmbeddingResponseData]
    usage: UsageInfo


1892
1893
1894
class PoolingResponseData(OpenAIBaseModel):
    """One pooling entry; `data` is a nested or flat float list, or a string."""

    index: int
    object: str = "pooling"
    data: list[list[float]] | list[float] | str
1896
1897
1898
1899
1900
1901
1902


class PoolingResponse(OpenAIBaseModel):
    """Pooling response: a list of `PoolingResponseData` plus usage."""

    # Defaults to "pool-" followed by a random uuid.
    id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
    object: str = "list"
    # Defaults to the current Unix time in whole seconds at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[PoolingResponseData]
    usage: UsageInfo


1907
1908
1909
class ScoreResponseData(OpenAIBaseModel):
    """One score entry: an index and its float score."""

    index: int
    object: str = "score"
    score: float
1911
1912
1913
1914
1915
1916
1917


class ScoreResponse(OpenAIBaseModel):
    """Score response: a list of `ScoreResponseData` plus usage."""

    # NOTE(review): the default id uses the "embd-" prefix, the same prefix as
    # EmbeddingResponse - confirm this is intentional.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    # Defaults to the current Unix time in whole seconds at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ScoreResponseData]
    usage: UsageInfo


1922
class ClassificationRequest(OpenAIBaseModel):
    """Request model carrying one string or a list of strings to classify."""

    model: str | None = None
    input: list[str] | str
    # Constrained to >= -1, matching the other request models in this file
    # (embedding, score, rerank), so out-of-range values are rejected during
    # request validation rather than being passed on to `PoolingParams`.
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    user: str | None = None

    # --8<-- [start:classification-extra-params]
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    activation: bool | None = None

    # --8<-- [end:classification-extra-params]

    def to_pooling_params(self):
        """Build `PoolingParams` from the truncation and activation settings."""
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            activation=self.activation,
        )
1947
1948
1949
1950


class ClassificationData(OpenAIBaseModel):
    """One classification entry: index, label, probabilities, class count."""

    index: int
    # Required field, but its value may be None.
    label: str | None
    probs: list[float]
    num_classes: int


class ClassificationResponse(OpenAIBaseModel):
    """Classification response: a list of `ClassificationData` plus usage."""

    # Defaults to "classify-" followed by a random uuid.
    id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
    object: str = "list"
    # Defaults to the current Unix time in whole seconds at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ClassificationData]
    usage: UsageInfo


1965
1966
1967
1968
1969
1970
class FunctionCall(OpenAIBaseModel):
    """A function call: the function name and its arguments as a string."""

    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    """A tool call of type "function" wrapping a `FunctionCall`."""

    # Generated by `make_tool_call_id` when the caller does not supply one.
    id: str = Field(default_factory=make_tool_call_id)
    type: Literal["function"] = "function"
    function: FunctionCall


1976
class DeltaFunctionCall(BaseModel):
    """Function-call delta in which both name and arguments are optional."""

    name: str | None = None
    arguments: str | None = None
1979
1980
1981
1982


# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
    """Tool-call delta: only `index` is required, the rest may be omitted."""

    id: str | None = None
    type: Literal["function"] | None = None
    index: int
    function: DeltaFunctionCall | None = None
1987
1988
1989
1990
1991
1992
1993


class ExtractedToolCallInformation(BaseModel):
    """Result of extracting tool calls: a flag, the calls and any content."""

    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: list[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: str | None = None
1999
2000


2001
class ChatMessage(OpenAIBaseModel):
    """A chat message; only `role` is required."""

    role: str
    content: str | None = None
    refusal: str | None = None
    annotations: OpenAIAnnotation | None = None
    audio: OpenAIChatCompletionAudio | None = None
    function_call: FunctionCall | None = None
    # Defaults to a fresh empty list rather than None.
    tool_calls: list[ToolCall] = Field(default_factory=list)

    # vLLM-specific fields that are not in OpenAI spec
    reasoning_content: str | None = None
2012

2013

2014
2015
2016
class ChatCompletionLogProb(OpenAIBaseModel):
    """Log probability of one token, with optional byte values."""

    token: str
    # Defaults to -9999.0 when no value is supplied.
    logprob: float = -9999.0
    bytes: list[int] | None = None
2018
2019
2020


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    """A token log probability extended with a list of top alternatives."""

    # Workaround: redefine fields name cache so that it's not
    # shared with the super class.
    field_names: ClassVar[set[str] | None] = None
    top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
2025
2026
2027


class ChatCompletionLogProbs(OpenAIBaseModel):
    """Container for an optional list of per-token log-probability entries."""

    content: list[ChatCompletionLogProbsContent] | None = None
2029
2030


2031
class ChatCompletionResponseChoice(OpenAIBaseModel):
    """A single choice of a non-streaming chat completion response."""

    index: int
    message: ChatMessage
    logprobs: ChatCompletionLogProbs | None = None
    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: int | str | None = None
    # not part of the OpenAI spec but is useful for tracing the tokens
    # in agent scenarios
    token_ids: list[int] | None = None
2042
2043


2044
class ChatCompletionResponse(OpenAIBaseModel):
    """Non-streaming chat completion response (`object` "chat.completion")."""

    # Defaults to "chatcmpl-" followed by a random uuid.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    # Defaults to the current Unix time in whole seconds at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    prompt_token_ids: list[int] | None = None
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None, description="KVTransfer parameters."
    )
2060
2061


2062
class DeltaMessage(OpenAIBaseModel):
    """Message delta in which every field is optional."""

    role: str | None = None
    content: str | None = None
    reasoning_content: str | None = None
    # Defaults to a fresh empty list rather than None.
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)
2067
2068


2069
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    """A single choice inside a streamed chat completion chunk."""

    index: int
    delta: DeltaMessage
    logprobs: ChatCompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None
    # not part of the OpenAI spec but for tracing the tokens
    token_ids: list[int] | None = None
2077
2078


2079
class ChatCompletionStreamResponse(OpenAIBaseModel):
    """One streamed chat completion chunk (`object` "chat.completion.chunk")."""

    # Defaults to "chatcmpl-" followed by a random uuid.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Defaults to the current Unix time in whole seconds at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
    # not part of the OpenAI spec but for tracing the tokens
    prompt_token_ids: list[int] | None = None
2088
2089


2090
2091
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    """A single choice inside a streamed transcription chunk."""

    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None
2094
2095
2096
2097
2098
2099
2100
2101


class TranscriptionStreamResponse(OpenAIBaseModel):
    """One streamed transcription chunk (`object` "transcription.chunk")."""

    # Defaults to "trsc-" followed by a random uuid.
    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    object: Literal["transcription.chunk"] = "transcription.chunk"
    # Defaults to the current Unix time in whole seconds at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
2103
2104


2105
2106
class InputTokensDetails(OpenAIBaseModel):
    """Input-token details: cached count plus optional per-turn lists."""

    cached_tokens: int
    # The per-turn lists default to fresh empty lists.
    input_tokens_per_turn: list[int] = Field(default_factory=list)
    cached_tokens_per_turn: list[int] = Field(default_factory=list)
2109
2110
2111


class OutputTokensDetails(OpenAIBaseModel):
    """Output-token details: reasoning/tool counts plus per-turn lists."""

    reasoning_tokens: int = 0
    tool_output_tokens: int = 0
    # The per-turn lists default to fresh empty lists.
    output_tokens_per_turn: list[int] = Field(default_factory=list)
    tool_output_tokens_per_turn: list[int] = Field(default_factory=list)
2116
2117
2118
2119
2120
2121
2122
2123


class ResponseUsage(OpenAIBaseModel):
    """Token usage: input, output and total counts with their details."""

    input_tokens: int
    input_tokens_details: InputTokensDetails
    output_tokens: int
    output_tokens_details: OutputTokensDetails
    total_tokens: int
2124
2125


2126
2127
2128
2129
2130
2131
def serialize_message(msg):
    """
    Serializes a single message into a JSON-compatible object.

    - A dict is returned unchanged.
    - An object exposing `to_dict()` is converted through that method.
    - Anything else is treated as a pydantic model and converted with
      `model_dump(mode="json")`.
    """
    if isinstance(msg, dict):
        return msg
    if hasattr(msg, "to_dict"):
        return msg.to_dict()
    # Fallback to a pydantic dump. `model_dump_json()` would return an
    # already-encoded JSON string, which is then encoded a second time when
    # the enclosing response is serialized; `model_dump(mode="json")` yields
    # the equivalent plain structure instead.
    return msg.model_dump(mode="json")


def serialize_messages(msgs):
    """
    Serializes a sequence of messages with `serialize_message`.

    Returns None when `msgs` is None or empty, otherwise a list with one
    serialized entry per message.
    """
    if not msgs:
        return None
    serialized = []
    for item in msgs:
        serialized.append(serialize_message(item))
    return serialized


2146
2147
2148
2149
class ResponsesResponse(OpenAIBaseModel):
    """Response object (`object` "response") returned for a `ResponsesRequest`.

    Instances are normally built through `from_request`, which copies the
    echoed settings from the request and the sampling parameters.
    """

    # Defaults to "resp_" followed by a random uuid.
    id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
    # Defaults to the current Unix time in whole seconds at construction.
    created_at: int = Field(default_factory=lambda: int(time.time()))
    # error: Optional[ResponseError] = None
    incomplete_details: IncompleteDetails | None = None
    instructions: str | None = None
    metadata: Metadata | None = None
    model: str
    object: Literal["response"] = "response"
    output: list[ResponseOutputItem]
    parallel_tool_calls: bool
    temperature: float
    tool_choice: ToolChoice
    tools: list[Tool]
    top_p: float
    background: bool
    max_output_tokens: int
    max_tool_calls: int | None = None
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"]
    status: ResponseStatus
    text: ResponseTextConfig | None = None
    top_logprobs: int | None = None
    truncation: Literal["auto", "disabled"]
    usage: ResponseUsage | None = None
    user: str | None = None

    # --8<-- [start:responses-extra-params]
    # These are populated when enable_response_messages is set to True
    # NOTE: custom serialization is needed
    # see serialize_input_messages and serialize_output_messages
    input_messages: list[ChatCompletionMessageParam] | None = None
    output_messages: list[ChatCompletionMessageParam] | None = None
    # --8<-- [end:responses-extra-params]

    # NOTE: openAI harmony doesn't serialize TextContent properly,
    # TODO: this fixes for TextContent, but need to verify for tools etc
    # https://github.com/openai/harmony/issues/78
    @field_serializer("output_messages", when_used="json")
    def serialize_output_messages(self, msgs, _info):
        # Only applied for JSON output; delegates to `serialize_messages`.
        return serialize_messages(msgs)

    # NOTE: openAI harmony doesn't serialize TextContent properly, this fixes it
    # https://github.com/openai/harmony/issues/78
    @field_serializer("input_messages", when_used="json")
    def serialize_input_messages(self, msgs, _info):
        # Only applied for JSON output; delegates to `serialize_messages`.
        return serialize_messages(msgs)

    @classmethod
    def from_request(
        cls,
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        model_name: str,
        created_time: int,
        output: list[ResponseOutputItem],
        status: ResponseStatus,
        usage: ResponseUsage | None = None,
        input_messages: list[ChatCompletionMessageParam] | None = None,
        output_messages: list[ChatCompletionMessageParam] | None = None,
    ) -> "ResponsesResponse":
        """Build a response from the originating request and its results.

        The id is taken from `request.request_id`. `temperature`, `top_p`,
        `max_output_tokens` and `top_logprobs` come from `sampling_params`
        (`temperature`, `top_p`, `max_tokens`, `logprobs`); the remaining
        echoed settings are copied from `request`.

        When `status` equals "incomplete", `incomplete_details` is set with
        reason "max_output_tokens"; otherwise it stays None.
        """
        incomplete_details: IncompleteDetails | None = None
        if status == "incomplete":
            incomplete_details = IncompleteDetails(reason="max_output_tokens")
        # TODO: implement the other reason for incomplete_details,
        # which is content_filter
        # incomplete_details = IncompleteDetails(reason='content_filter')
        return cls(
            id=request.request_id,
            created_at=created_time,
            incomplete_details=incomplete_details,
            instructions=request.instructions,
            metadata=request.metadata,
            model=model_name,
            output=output,
            input_messages=input_messages,
            output_messages=output_messages,
            parallel_tool_calls=request.parallel_tool_calls,
            temperature=sampling_params.temperature,
            tool_choice=request.tool_choice,
            tools=request.tools,
            top_p=sampling_params.top_p,
            background=request.background,
            max_output_tokens=sampling_params.max_tokens,
            max_tool_calls=request.max_tool_calls,
            previous_response_id=request.previous_response_id,
            prompt=request.prompt,
            reasoning=request.reasoning,
            service_tier=request.service_tier,
            status=status,
            text=request.text,
            top_logprobs=sampling_params.logprobs,
            truncation=request.truncation,
            user=request.user,
            usage=usage,
        )


2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
class ResponseReasoningPartDoneEvent(OpenAIBaseModel):
    # Streaming event of type `response.reasoning_part.done`, defined locally
    # until the upstream issue referenced above this class is resolved.
    content_index: int
    """The index of the content part that is done."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that is done."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.done"]
    """The type of the event. Always `response.reasoning_part.done`."""


# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
class ResponseReasoningPartAddedEvent(OpenAIBaseModel):
    # Streaming event of type `response.reasoning_part.added`. The attribute
    # docs below describe the part as "added"; they previously repeated the
    # wording of the corresponding "done" event.
    content_index: int
    """The index of the content part that was added."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that was added."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.added"]
    """The type of the event. Always `response.reasoning_part.added`."""


2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
# vLLM Streaming Events
# Note: we override the response type with the vLLM ResponsesResponse type
class ResponseCompletedEvent(OpenAIResponseCompletedEvent):
    """OpenAI completed event whose `response` is a vLLM `ResponsesResponse`."""

    response: ResponsesResponse  # type: ignore[override]


class ResponseCreatedEvent(OpenAIResponseCreatedEvent):
    """OpenAI created event whose `response` is a vLLM `ResponsesResponse`."""

    response: ResponsesResponse  # type: ignore[override]


class ResponseInProgressEvent(OpenAIResponseInProgressEvent):
    """OpenAI in-progress event whose `response` is a vLLM `ResponsesResponse`."""

    response: ResponsesResponse  # type: ignore[override]


2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
# Union of every event type covered by this alias: the three vLLM overrides
# (created / in-progress / completed), the locally defined reasoning-part
# events, and the event classes imported from the OpenAI package.
StreamingResponsesResponse: TypeAlias = (
    ResponseCreatedEvent
    | ResponseInProgressEvent
    | ResponseCompletedEvent
    | ResponseOutputItemAddedEvent
    | ResponseOutputItemDoneEvent
    | ResponseContentPartAddedEvent
    | ResponseContentPartDoneEvent
    | ResponseReasoningTextDeltaEvent
    | ResponseReasoningTextDoneEvent
    | ResponseReasoningPartAddedEvent
    | ResponseReasoningPartDoneEvent
    | ResponseCodeInterpreterCallInProgressEvent
    | ResponseCodeInterpreterCallCodeDeltaEvent
    | ResponseWebSearchCallInProgressEvent
    | ResponseWebSearchCallSearchingEvent
    | ResponseWebSearchCallCompletedEvent
    | ResponseCodeInterpreterCallCodeDoneEvent
    | ResponseCodeInterpreterCallInterpretingEvent
    | ResponseCodeInterpreterCallCompletedEvent
)
2325

2326
2327
2328
# Request models that may appear as the `body` of a batch input line.
BatchRequestInputBody: TypeAlias = (
    ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest
)
2329
2330


2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: The request `body` is validated against the model selected by
    `url`: `/v1/chat/completions`, `/v1/embeddings`, and URLs ending in
    `/score` or `/rerank` are recognised; anything else is validated against
    the `BatchRequestInputBody` union.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. It also selects
    # which request model `body` is validated against (see the validator).
    url: str

    # The parameters of the request.
    body: BatchRequestInputBody

    @field_validator("body", mode="plain")
    @classmethod
    def check_type_for_url(cls, value: Any, info: ValidationInfo):
        """Validate `body` with the request model that matches `url`.

        Args:
            value: The raw, not yet validated `body` payload.
            info: Validation context; `info.data` holds the fields that were
                declared before `body` and validated successfully.

        Returns:
            The validated request model instance.
        """
        # `info.data` only contains fields that passed validation. If `url`
        # is missing or invalid it is absent, so use `.get` instead of
        # indexing: a KeyError raised here would escape as a raw exception
        # rather than a validation error. The `url` problem is still reported
        # by pydantic for that field.
        url = info.data.get("url")
        if isinstance(url, str):
            # Use url to disambiguate models
            if url == "/v1/chat/completions":
                return ChatCompletionRequest.model_validate(value)
            if url == "/v1/embeddings":
                return TypeAdapter(EmbeddingRequest).validate_python(value)
            if url.endswith("/score"):
                return ScoreRequest.model_validate(value)
            if url.endswith("/rerank"):
                return RerankRequest.model_validate(value)
        # Unknown or unavailable url: let pydantic pick from the union.
        return TypeAdapter(BatchRequestInputBody).validate_python(value)
2367

2368

2369
2370
2371
2372
2373
2374
2375
2376
class BatchResponseData(OpenAIBaseModel):
    """Response payload recorded for a single request of a batch."""

    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response; stays None when no body is set.
    body: (
        ChatCompletionResponse
        | EmbeddingResponse
        | ScoreResponse
        | RerankResponse
        | None
    ) = None
2384
2385


2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    # Required field (no default); None is an accepted value.
    response: BatchResponseData | None

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Any | None
2402
2403


2404
class TokenizeCompletionRequest(OpenAIBaseModel):
    """Tokenize request carrying a plain text `prompt`."""

    model: str | None = None
    prompt: str

    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
2421
2422
2423


class TokenizeChatRequest(OpenAIBaseModel):
    """Tokenize request carrying chat `messages` instead of a plain prompt."""

    model: str | None = None
    messages: list[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    tools: list[ChatCompletionToolsParam] | None = Field(
        default=None,
        description=("A list of tools the model may call."),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests that set both mutually exclusive prompt flags.

        Runs on the raw input before field validation, so only keys that are
        explicitly present in `data` are inspected; field defaults are not
        applied yet.

        Raises:
            ValueError: If both `continue_final_message` and
                `add_generation_prompt` are truthy in the input.
        """
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

2496

2497
# A tokenize request is either the plain-prompt or the chat-messages variant.
TokenizeRequest: TypeAlias = TokenizeCompletionRequest | TokenizeChatRequest
2498
2499
2500
2501
2502


class TokenizeResponse(OpenAIBaseModel):
    """Result of a tokenize request."""

    # NOTE(review): presumably the number of entries in `tokens` — confirm
    # against the code that builds this response.
    count: int
    max_model_len: int
    tokens: list[int]
    # Stays None unless token strings are provided.
    token_strs: list[str] | None = None
2505
2506
2507


class DetokenizeRequest(OpenAIBaseModel):
    """Request to turn a list of token ids back into text."""

    model: str | None = None
    tokens: list[int]
2510
2511
2512
2513


class DetokenizeResponse(OpenAIBaseModel):
    # NOTE(review): presumably the text decoded from the request's token ids —
    # confirm against the detokenize handler.
    prompt: str
2514
2515


2516
2517
class TokenizerInfoResponse(OpenAIBaseModel):
    """
    Response containing tokenizer configuration
    equivalent to tokenizer_config.json
    """

    # Extra keys are accepted and kept, so arbitrary tokenizer config entries
    # can be carried alongside the one declared field.
    model_config = ConfigDict(extra="allow")
    tokenizer_class: str


2526
class LoadLoRAAdapterRequest(BaseModel):
    """Request to load a LoRA adapter, identified by name and path."""

    lora_name: str
    lora_path: str


2531
class UnloadLoRAAdapterRequest(BaseModel):
    """Request to unload a LoRA adapter by name, optionally with its int id."""

    lora_name: str
    lora_int_id: int | None = Field(default=None)
2534
2535
2536


## Protocols for Audio
2537
# Allowed values for the `response_format` field of the audio request models.
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
2538
2539
2540
2541


class TranscriptionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    language: str | None = None
    """The language of the input audio.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    ## TODO (varun) : Support if set to 0, certain thresholds are met !!

    timestamp_granularities: list[Literal["word", "segment"]] = Field(
        alias="timestamp_granularities[]", default=[]
    )
    """The timestamp granularities to populate for this transcription.

    `response_format` must be set `verbose_json` to use timestamp granularities.
    Either or both of these options are supported: `word`, or `segment`. Note:
    There is no additional latency for segment timestamps, but generating word
    timestamps incurs additional latency.
    """

    stream: bool | None = False
    """When set, it will enable output to be streamed in a similar fashion
    as the Chat Completion endpoint.
    """
    # --8<-- [start:transcription-extra-params]
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )
    # --8<-- [end:transcription-extra-params]

    to_language: str | None = None
    """The language of the output audio we transcribe to.

    Please note that this is not currently used by supported models at this
    time, but it is a placeholder for future use, matching translation api.
    """

    # --8<-- [start:transcription-sampling-params]
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    top_p: float | None = None
    """Enables nucleus (top-p) sampling, where tokens are selected from the
    smallest possible set whose cumulative probability exceeds `p`.
    """

    top_k: int | None = None
    """Limits sampling to the `k` most probable tokens at each step."""

    min_p: float | None = None
    """Filters out tokens with a probability lower than `min_p`, ensuring a
    minimum likelihood threshold during sampling.
    """

    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    frequency_penalty: float | None = 0.0
    """The frequency penalty to use for sampling."""

    repetition_penalty: float | None = None
    """The repetition penalty to use for sampling."""

    presence_penalty: float | None = 0.0
    """The presence penalty to use for sampling."""
    # --8<-- [end:transcription-sampling-params]

    # Default sampling parameters for transcription requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Build the `SamplingParams` for this transcription request.

        For each of the optional sampling fields, the value is resolved in
        this order: the request field if it is not None, then
        `default_sampling_params`, then `_DEFAULT_SAMPLING_PARAMS`.

        Args:
            default_max_tokens: Used as `max_tokens`; the request itself has
                no max-tokens field.
            default_sampling_params: Optional fallback values keyed by
                parameter name. None is treated as an empty dict.

        Returns:
            Sampling parameters whose output kind is DELTA when `stream` is
            truthy and FINAL_ONLY otherwise.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        # NOTE(review): `temperature` is declared as a non-optional float with
        # default 0.0, so unless validation is bypassed this branch is never
        # taken and the fallback temperature values are unused.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            presence_penalty=self.presence_penalty,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            extra_args=self.vllm_xargs,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_transcription_request(cls, data):
        """Check the raw request before field validation.

        Raises:
            HTTPException: With status 422 if `file` was sent as a string
                instead of a file upload.
            ValueError: If a stream option is set while `stream` is falsy.
        """
        if isinstance(data.get("file"), str):
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail="Expected 'file' to be a file-like object, not 'str'.",
            )

        # The flattened stream options only make sense for streamed output.
        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data
2722
2723
2724


# Transcription response objects
2725
2726
2727
2728
2729
class TranscriptionUsageAudio(OpenAIBaseModel):
    # Discriminator for this usage record; the only accepted value is
    # "duration".
    type: Literal["duration"] = "duration"
    # NOTE(review): presumably the audio duration in whole seconds — confirm
    # against the code that fills this in.
    seconds: int


2730
2731
2732
class TranscriptionResponse(OpenAIBaseModel):
    text: str
    """The transcribed text."""
    # Usage record attached to the transcription; required (no default).
    usage: TranscriptionUsageAudio
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784


class TranscriptionWord(OpenAIBaseModel):
    """A single transcribed word with its start and end time."""

    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranscriptionSegment(OpenAIBaseModel):
    """One segment of a transcription together with its decoding details."""

    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranscriptionResponseVerbose(OpenAIBaseModel):
    """Transcription result with optional segment and word level details."""

    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: list[TranscriptionSegment] | None = None
    """Segments of the transcribed text and their corresponding details."""

    words: list[TranscriptionWord] | None = None
    """Extracted words and their corresponding timestamps."""
2804
2805
2806
2807


class TranslationResponseStreamChoice(OpenAIBaseModel):
    """One choice entry inside a streamed translation chunk."""

    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None
2810
2811
2812
2813
2814
2815
2816
2817


class TranslationStreamResponse(OpenAIBaseModel):
    """A single chunk of a streamed translation response."""

    # Generated per instance: "trsl-" followed by a random uuid.
    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
    object: Literal["translation.chunk"] = "translation.chunk"
    # Creation time as integer seconds, taken when the instance is built.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranslationResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
2819
2820
2821
2822
2823
2824
2825
2826
2827
2828
2829
2830


class TranslationRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranslation

    file: UploadFile
    """
    The audio file object (not file name) to translate, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    # TODO support additional sampling parameters
    # --8<-- [start:translation-sampling-params]
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """
    # --8<-- [end:translation-sampling-params]

    # --8<-- [start:translation-extra-params]
    language: str | None = None
    """The language of the input audio we translate from.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy.
    """

    to_language: str | None = None
    """The language we translate the input audio to.

    Please note that this is not supported by all models, refer to the specific
    model documentation for more details.
    For instance, Whisper only supports `to_language=en`.
    """

    stream: bool | None = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
    Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False
    # --8<-- [end:translation-extra-params]

    # Default sampling parameters for translation requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Build the `SamplingParams` for this translation request.

        Args:
            default_max_tokens: Used as `max_tokens`; the request itself has
                no max-tokens field.
            default_sampling_params: Optional fallback values keyed by
                parameter name. None is treated as an empty dict.

        Returns:
            Sampling parameters whose output kind is DELTA when `stream` is
            truthy and FINAL_ONLY otherwise.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        # Default parameters
        # NOTE(review): `temperature` is declared as a non-optional float with
        # default 0.0, so unless validation is bypassed this branch is never
        # taken and the fallback temperature values are unused.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject stream options on the raw input when streaming is off.

        Raises:
            ValueError: If `stream_include_usage` or
                `stream_continuous_usage_stats` is truthy while `stream` is
                falsy or absent.
        """
        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data


# Translation response objects
class TranslationResponse(OpenAIBaseModel):
    # Minimal translation result: carries only the translated text.
    text: str
    """The translated text."""


class TranslationWord(OpenAIBaseModel):
    """A single translated word with its start and end time."""

    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranslationSegment(OpenAIBaseModel):
    """One segment of a translation together with its decoding details."""

    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranslationResponseVerbose(OpenAIBaseModel):
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The translated text."""

2998
    segments: list[TranslationSegment] | None = None
2999
3000
    """Segments of the translated text and their corresponding details."""

3001
    words: list[TranslationWord] | None = None
3002
    """Extracted words and their corresponding timestamps."""