protocol.py 86.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
6
import json
Zhuohan Li's avatar
Zhuohan Li committed
7
import time
8
from http import HTTPStatus
9
from typing import Annotated, Any, ClassVar, Literal, TypeAlias
Zhuohan Li's avatar
Zhuohan Li committed
10

11
import regex as re
12
import torch
13
from fastapi import HTTPException, UploadFile
14
from openai.types.chat.chat_completion_audio import (
15
16
17
    ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
18
19
20
21
22
from openai.types.responses import (
    ResponseCodeInterpreterCallCodeDeltaEvent,
    ResponseCodeInterpreterCallCodeDoneEvent,
    ResponseCodeInterpreterCallCompletedEvent,
    ResponseCodeInterpreterCallInProgressEvent,
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
    ResponseCodeInterpreterCallInterpretingEvent,
    ResponseContentPartAddedEvent,
    ResponseContentPartDoneEvent,
    ResponseFunctionToolCall,
    ResponseInputItemParam,
    ResponseOutputItem,
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponsePrompt,
    ResponseReasoningTextDeltaEvent,
    ResponseReasoningTextDoneEvent,
    ResponseStatus,
    ResponseWebSearchCallCompletedEvent,
    ResponseWebSearchCallInProgressEvent,
    ResponseWebSearchCallSearchingEvent,
)
39
from openai.types.responses import (
40
41
42
    ResponseCompletedEvent as OpenAIResponseCompletedEvent,
)
from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent
43
from openai.types.responses import (
44
45
    ResponseInProgressEvent as OpenAIResponseInProgressEvent,
)
46
from openai.types.responses.response_reasoning_item import (
47
48
    Content as ResponseReasoningTextContent,
)
49
from openai_harmony import Message as OpenAIHarmonyMessage
50
51
52
53
54

# Backward compatibility for OpenAI client versions
try:  # For older openai versions (< 1.100.0)
    from openai.types.responses import ResponseTextConfig
except ImportError:  # For newer openai versions (>= 1.100.0)
55
    from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig
56

57

58
from openai.types.responses.response import IncompleteDetails, ToolChoice
59
60
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
61
62
63
64
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
65
    ValidationError,
66
    field_serializer,
67
68
    model_validator,
)
Zhuohan Li's avatar
Zhuohan Li committed
69

70
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
71
from vllm.logger import init_logger
72
from vllm.logprobs import Logprob
73
74
75
76
77
78
from vllm.sampling_params import (
    BeamSearchParams,
    RequestOutputKind,
    SamplingParams,
    StructuredOutputsParams,
)
79
80
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
81

82
83
logger = init_logger(__name__)

# Bounds of a signed 64-bit torch integer (torch.long); ChatCompletionRequest
# uses them to range-check the ``seed`` field.
_LONG_INFO = torch.iinfo(torch.int64)
85

Zhuohan Li's avatar
Zhuohan Li committed
86

87
class OpenAIBaseModel(BaseModel):
    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Per-class cache of accepted field names and aliases, filled lazily the
    # first time a given class validates a dict payload.
    field_names: ClassVar[set[str] | None] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Validate ``data`` and log any request keys that match no field.

        Because ``extra="allow"`` accepts unknown keys silently, this wrap
        validator warns about them so that unsupported or misspelled request
        parameters do not go unnoticed.

        Args:
            data: The raw input handed to validation.
            handler: The inner pydantic validator to delegate to.

        Returns:
            Whatever ``handler(data)`` returns (the validated model).
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Look the cache up in this class's own namespace only. A plain
        # ``cls.field_names`` would fall back to a parent class's cached set,
        # so a subclass validated after its parent would be compared against
        # the parent's fields and log spurious warnings.
        field_names = cls.__dict__.get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        ignored = data.keys() - field_names
        if ignored:
            logger.warning(
                "The following fields were present in the request but ignored: %s",
                ignored,
            )
        return result
117
118


119
class ErrorInfo(OpenAIBaseModel):
    # Contents of ``ErrorResponse.error``. Only ``param`` is optional.
    message: str
    type: str
    param: str | None = None
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
124
125


126
127
128
129
class ErrorResponse(OpenAIBaseModel):
    # Error envelope: all details are nested under the ``error`` key.
    error: ErrorInfo


130
class ModelPermission(OpenAIBaseModel):
    # Permission entry listed in ``ModelCard.permission``. ``id`` and
    # ``created`` are generated per instance; everything else is a fixed
    # default.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    # Seconds since the epoch, truncated to an int.
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: str | None = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
143
144


145
class ModelCard(OpenAIBaseModel):
    # One entry of ``ModelList.data``.
    id: str
    object: str = "model"
    # Seconds since the epoch at construction time.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: str | None = None
    parent: str | None = None
    max_model_len: int | None = None
    # A fresh empty list per instance.
    permission: list[ModelPermission] = Field(default_factory=lambda: [])
Zhuohan Li's avatar
Zhuohan Li committed
154
155


156
class ModelList(OpenAIBaseModel):
    # List wrapper: ``object`` is always "list" unless overridden, and
    # ``data`` starts as a new empty list for every instance.
    object: str = "list"
    data: list[ModelCard] = Field(default_factory=lambda: [])
Zhuohan Li's avatar
Zhuohan Li committed
159
160


161
class PromptTokenUsageInfo(OpenAIBaseModel):
    # Detail object carried in ``UsageInfo.prompt_tokens_details``.
    cached_tokens: int | None = None
163
164


165
class UsageInfo(OpenAIBaseModel):
    # Token counters default to zero; ``completion_tokens`` may also be
    # explicitly null, and the prompt breakdown is omitted unless provided.
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: int | None = 0
    prompt_tokens_details: PromptTokenUsageInfo | None = None
Zhuohan Li's avatar
Zhuohan Li committed
170
171


172
173
class RequestResponseMetadata(BaseModel):
    # Derives from plain BaseModel rather than OpenAIBaseModel, so the
    # extra-field allowance and logging defined there do not apply here.
    request_id: str
    final_usage_info: UsageInfo | None = None
175
176


177
178
class JsonSchemaResponseFormat(OpenAIBaseModel):
    name: str
    description: str | None = None
    # The wire name is ``schema``, which clashes with a pydantic attribute,
    # so the value is stored as ``json_schema`` and read via the alias.
    json_schema: dict[str, Any] | None = Field(None, alias="schema")
    strict: bool | None = None
184
185


186
class LegacyStructuralTag(OpenAIBaseModel):
    # One entry of ``LegacyStructuralTagResponseFormat.structures``.
    begin: str
    # Exposed to clients as ``schema``; renamed internally to avoid a clash
    # with pydantic's own ``schema`` attribute.
    structural_tag_schema: dict[str, Any] | None = Field(None, alias="schema")
    end: str


194
class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
    # Older structural-tag shape: explicit tag structures plus trigger
    # strings, all required.
    type: Literal["structural_tag"]
    structures: list[LegacyStructuralTag]
    triggers: list[str]


200
201
202
203
204
205
206
207
208
209
class StructuralTagResponseFormat(OpenAIBaseModel):
    # Newer structural-tag shape; ``format`` is accepted without any
    # structural validation at this layer.
    type: Literal["structural_tag"]
    format: Any


# Either structural-tag payload shape. Member order is kept as-is because
# pydantic may consult it when resolving the union.
AnyStructuralTagResponseFormat: TypeAlias = (
    LegacyStructuralTagResponseFormat
    | StructuralTagResponseFormat
)


210
class ResponseFormat(OpenAIBaseModel):
    # Allowed kinds: plain text, any JSON object, or JSON constrained by
    # ``json_schema`` (which is only meaningful for the "json_schema" kind).
    type: Literal["text", "json_object", "json_schema"]
    json_schema: JsonSchemaResponseFormat | None = None
214
215


216
217
218
# Every ``response_format`` shape a chat request may send. Member order is
# preserved since pydantic may consult it when resolving the union.
AnyResponseFormat: TypeAlias = (
    ResponseFormat
    | StructuralTagResponseFormat
    | LegacyStructuralTagResponseFormat
)
219
220


221
class StreamOptions(OpenAIBaseModel):
    # By default usage is included but not reported continuously.
    include_usage: bool | None = True
    continuous_usage_stats: bool | None = False
224
225


226
227
class FunctionDefinition(OpenAIBaseModel):
    # Function description carried in ``ChatCompletionToolsParam.function``;
    # only ``name`` is mandatory.
    name: str
    description: str | None = None
    parameters: dict[str, Any] | None = None
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245


class ChatCompletionToolsParam(OpenAIBaseModel):
    # A tool entry; "function" is the only accepted (and default) type.
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    # Identifies a function by name only.
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    # Forces a specific function; ``type`` may be omitted and then
    # defaults to "function".
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


246
247
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
248
249
class LogitsProcessorConstructor(BaseModel):
    # Unknown keys are rejected (extra="forbid"); this is the workaround that
    # lets ``kwargs`` be declared as an ordinary field.
    model_config = ConfigDict(extra="forbid")

    # Dotted path of the processor class/factory to resolve.
    qualname: str
    # Positional and keyword arguments used to call the resolved object.
    args: list[Any] | None = None
    kwargs: dict[str, Any] | None = None

255

256
# Each entry is either a bare qualified name or a constructor description.
LogitsProcessors: TypeAlias = list[str | LogitsProcessorConstructor]
257
258


259
def get_logits_processors(
    processors: LogitsProcessors | None, pattern: str | None
) -> list[Any] | None:
    """Resolve client-requested logits processors against the server's pattern.

    Args:
        processors: Requested processors; each is a qualified name or a
            ``LogitsProcessorConstructor`` whose resolved object is called
            with its ``args``/``kwargs``.
        pattern: Regex every qualified name must satisfy, or ``None``/empty
            when the server does not accept custom processors.

    Returns:
        The resolved objects (constructor entries already instantiated), or
        ``None`` if no processors were requested.

    Raises:
        ValueError: If processors are given but no pattern is configured,
            a name fails the pattern, or a name cannot be resolved.
    """
    if not processors:
        return None
    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information."
        )

    resolved: list[Any] = []
    for entry in processors:
        if isinstance(entry, str):
            qualname = entry
        else:
            qualname = entry.qualname

        # NOTE(review): ``re.match`` only anchors at the start of the name, so
        # the pattern behaves as a prefix check unless it anchors the end
        # itself.
        if re.match(pattern, qualname) is None:
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information."
            )

        try:
            obj = resolve_obj_by_qualname(qualname)
        except Exception as exc:
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {exc}"
            ) from exc

        if isinstance(entry, LogitsProcessorConstructor):
            obj = obj(*(entry.args or []), **(entry.kwargs or {}))
        resolved.append(obj)

    return resolved


293
# Items accepted in ``ResponsesRequest.input``: either an input item param
# or a previously produced output item.
ResponseInputOutputItem: TypeAlias = (
    ResponseInputItemParam | ResponseOutputItem
)
294
295


296
297
298
class ResponsesRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/responses/create
    background: bool | None = False
    include: (
        list[
            Literal[
                "code_interpreter_call.outputs",
                "computer_call_output.output.image_url",
                "file_search_call.results",
                "message.input_image.image_url",
                "message.output_text.logprobs",
                "reasoning.encrypted_content",
            ],
        ]
        | None
    ) = None
    input: str | list[ResponseInputOutputItem]
    instructions: str | None = None
    max_output_tokens: int | None = None
    max_tool_calls: int | None = None
    metadata: Metadata | None = None
    model: str | None = None
    parallel_tool_calls: bool | None = True
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
    store: bool | None = True
    stream: bool | None = False
    temperature: float | None = None
    text: ResponseTextConfig | None = None
    tool_choice: ToolChoice = "auto"
    tools: list[Tool] = Field(default_factory=list)
    top_logprobs: int | None = 0
    top_p: float | None = None
    truncation: Literal["auto", "disabled"] | None = "disabled"
    user: str | None = None

    # --8<-- [start:responses-extra-params]
    request_id: str = Field(
        default_factory=lambda: f"resp_{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    # The space after "for" was previously missing, so the concatenated
    # description read "supported fornon-background".
    enable_response_messages: bool = Field(
        default=False,
        description=(
            "Dictates whether or not to return messages as part of the "
            "response object. Currently only supported for "
            "non-background and gpt-oss only. "
        ),
    )
    # similar to input_messages / output_messages in ResponsesResponse
    # we take in previous_input_messages (ie in harmony format)
    # this cannot be used in conjunction with previous_response_id
    # TODO: consider supporting non harmony messages as well
    previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
    # --8<-- [end:responses-extra-params]

    _DEFAULT_SAMPLING_PARAMS = {
        "temperature": 1.0,
        "top_p": 1.0,
    }

    def to_sampling_params(
        self,
        default_max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Translate this request into engine ``SamplingParams``.

        Args:
            default_max_tokens: Used as ``max_tokens`` when
                ``max_output_tokens`` is unset, and as an upper bound on it
                otherwise.
            default_sampling_params: Server-side defaults. ``temperature``
                and ``top_p`` fall back to these (then to the class defaults)
                when the request leaves them unset; ``stop_token_ids`` is
                always taken from here.

        Returns:
            The sampling parameters to pass to the engine.

        Raises:
            NotImplementedError: If ``text.format`` is of type
                ``json_object``.
        """
        if self.max_output_tokens is None:
            max_tokens = default_max_tokens
        else:
            max_tokens = min(self.max_output_tokens, default_max_tokens)

        default_sampling_params = default_sampling_params or {}
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        stop_token_ids = default_sampling_params.get("stop_token_ids")

        # Structured output
        structured_outputs = None
        if self.text is not None and self.text.format is not None:
            response_format = self.text.format
            if (
                response_format.type == "json_schema"
                and response_format.schema_ is not None
            ):
                structured_outputs = StructuredOutputsParams(
                    json=response_format.schema_
                )
            elif response_format.type == "json_object":
                raise NotImplementedError("json_object is not supported")

        # TODO: add more parameters
        return SamplingParams.from_optional(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
            stop_token_ids=stop_token_ids,
            output_kind=(
                RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
            ),
            structured_outputs=structured_outputs,
        )

    def is_include_output_logprobs(self) -> bool:
        """Check if the request includes output logprobs."""
        # isinstance() already rules out None, so no separate check is needed.
        return (
            isinstance(self.include, list)
            and "message.output_text.logprobs" in self.include
        )

    @model_validator(mode="before")
    @classmethod
    def validate_background(cls, data):
        """Reject ``background`` requests whose ``store`` is falsy."""
        # Non-dict payloads are left for pydantic to reject with a proper
        # validation error instead of failing here on ``.get``.
        if not isinstance(data, dict) or not data.get("background"):
            return data
        if not data.get("store", True):
            raise ValueError("background can only be used when `store` is true")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt(cls, data):
        """Reject requests that set a ``prompt`` template (unsupported)."""
        if isinstance(data, dict) and data.get("prompt") is not None:
            raise ValueError("prompt template is not supported")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Require ``cache_salt``, when present and not None, to be a non-empty str."""
        if not isinstance(data, dict):
            return data
        if data.get("cache_salt") is not None and (
            not isinstance(data["cache_salt"], str) or not data["cache_salt"]
        ):
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def function_call_parsing(cls, data):
        """Parse function_call dictionaries into ResponseFunctionToolCall objects.

        This ensures Pydantic can properly resolve union types in the input field.
        Function calls provided as dicts are converted to ResponseFunctionToolCall
        objects before validation, while invalid structures are left for Pydantic
        to reject with appropriate error messages.
        """
        # Non-dict payloads are left for pydantic to reject with a proper
        # validation error instead of failing here on ``.get``.
        if not isinstance(data, dict):
            return data

        input_data = data.get("input")

        # Early return for None, strings, or bytes
        # (strings are iterable but shouldn't be processed)
        if input_data is None or isinstance(input_data, (str, bytes)):
            return data

        # Convert iterators (like ValidatorIterator) to list
        if not isinstance(input_data, list):
            try:
                input_data = list(input_data)
            except TypeError:
                # Not iterable, leave as-is for Pydantic to handle
                return data

        processed_input = []
        for item in input_data:
            if isinstance(item, dict) and item.get("type") == "function_call":
                try:
                    processed_input.append(ResponseFunctionToolCall(**item))
                except ValidationError:
                    # Let Pydantic handle validation for malformed function calls
                    logger.debug(
                        "Failed to parse function_call to ResponseFunctionToolCall, "
                        "leaving for Pydantic validation"
                    )
                    processed_input.append(item)
            else:
                processed_input.append(item)

        # Return an updated copy so the caller's request dict is not mutated.
        return {**data, "input": processed_input}

511

512
class ChatCompletionRequest(OpenAIBaseModel):
513
514
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
515
    messages: list[ChatCompletionMessageParam]
516
517
518
519
520
521
    model: str | None = None
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: bool | None = False
    top_logprobs: int | None = 0
    max_tokens: int | None = Field(
522
        default=None,
523
524
        deprecated="max_tokens is deprecated in favor of "
        "the max_completion_tokens field",
525
    )
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
    max_completion_tokens: int | None = None
    n: int | None = 1
    presence_penalty: float | None = 0.0
    response_format: AnyResponseFormat | None = None
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    temperature: float | None = None
    top_p: float | None = None
    tools: list[ChatCompletionToolsParam] | None = None
    tool_choice: (
        Literal["none"]
        | Literal["auto"]
        | Literal["required"]
        | ChatCompletionNamedToolChoiceParam
        | None
    ) = "none"
    reasoning_effort: Literal["low", "medium", "high"] | None = None
545
    include_reasoning: bool = True
546
    parallel_tool_calls: bool | None = True
547

548
    # NOTE this will be ignored by vLLM
549
    user: str | None = None
550

551
    # --8<-- [start:chat-completion-sampling-params]
552
    use_beam_search: bool = False
553
554
555
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
556
    length_penalty: float = 1.0
557
    stop_token_ids: list[int] | None = []
558
559
560
561
562
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
563
564
565
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    prompt_logprobs: int | None = None
    allowed_token_ids: list[int] | None = None
566
    bad_words: list[str] = Field(default_factory=list)
567
    # --8<-- [end:chat-completion-sampling-params]
568

569
    # --8<-- [start:chat-completion-extra-params]
570
    echo: bool = Field(
571
572
573
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
574
575
            "if they belong to the same role."
        ),
576
    )
577
    add_generation_prompt: bool = Field(
578
        default=True,
579
580
581
582
583
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
584
    )
585
586
    continue_final_message: bool = Field(
        default=False,
587
588
589
590
591
592
593
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
594
    )
595
    add_special_tokens: bool = Field(
596
597
598
599
600
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
601
            "special tokens so this should be set to false (as is the "
602
603
            "default)."
        ),
604
    )
605
    documents: list[dict[str, str]] | None = Field(
606
        default=None,
607
608
609
610
611
612
613
        description=(
            "A list of dicts representing documents that will be accessible to "
            "the model if it is performing RAG (retrieval-augmented generation)."
            " If the template does not support RAG, this argument will have no "
            "effect. We recommend that each document should be a dict containing "
            '"title" and "text" keys.'
        ),
614
    )
615
    chat_template: str | None = Field(
616
617
618
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
619
620
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
621
622
            "does not define one."
        ),
623
    )
624
    chat_template_kwargs: dict[str, Any] | None = Field(
625
        default=None,
626
627
        description=(
            "Additional keyword args to pass to the template renderer. "
628
629
            "Will be accessible by the chat template."
        ),
630
    )
631
    mm_processor_kwargs: dict[str, Any] | None = Field(
632
633
634
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
635
    structured_outputs: StructuredOutputsParams | None = Field(
636
        default=None,
637
        description="Additional kwargs for structured outputs",
638
    )
639
640
641
642
643
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
644
645
            "if the served model does not use priority scheduling."
        ),
646
    )
647
    request_id: str = Field(
648
        default_factory=random_uuid,
649
650
651
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
652
653
            "through out the inference process and return in response."
        ),
654
    )
655
    logits_processors: LogitsProcessors | None = Field(
656
657
658
659
660
661
662
663
664
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
665
666
667
            "{'param': 'value'}}."
        ),
    )
668
    return_tokens_as_token_ids: bool | None = Field(
669
670
671
672
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
673
674
675
            "that are not JSON-encodable can be identified."
        ),
    )
676
    return_token_ids: bool | None = Field(
677
678
679
680
681
682
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
683
684
685
            "need to map generated text back to input tokens."
        ),
    )
686
    cache_salt: str | None = Field(
687
688
689
690
691
692
693
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
694
            "to 256 bit)."
695
696
        ),
    )
697
    kv_transfer_params: dict[str, Any] | None = Field(
Robert Shaw's avatar
Robert Shaw committed
698
        default=None,
699
700
        description="KVTransfer parameters used for disaggregated serving.",
    )
701

702
    vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
703
        default=None,
704
        description=(
705
            "Additional request parameters with (list of) string or "
706
707
            "numeric values, used by custom extensions."
        ),
708
709
    )

710
    # --8<-- [end:chat-completion-extra-params]
Zhuohan Li's avatar
Zhuohan Li committed
711

712
713
714
715
716
    # Default sampling parameters for chat completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
717
        "top_k": 0,
718
719
720
721
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self, max_tokens: int, default_sampling_params: dict
    ) -> BeamSearchParams:
        """Build ``BeamSearchParams`` for this request.

        ``n`` is used as the beam width (1 when unset). An unset temperature
        falls back to ``default_sampling_params`` and then to this class's
        built-in default.
        """
        beam_width = 1 if self.n is None else self.n
        temperature = self.temperature
        if temperature is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return BeamSearchParams(
            beam_width=beam_width,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )
738

739
    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict,
    ) -> SamplingParams:
        """Convert this chat request into engine ``SamplingParams``.

        Args:
            max_tokens: Maximum number of tokens to generate.
            logits_processor_pattern: Pattern forwarded to
                ``get_logits_processors`` together with
                ``self.logits_processors``.
            default_sampling_params: Server-side defaults used for
                ``repetition_penalty``, ``temperature``, ``top_p``, ``top_k``
                and ``min_p`` when the request leaves them as ``None``;
                ``_DEFAULT_SAMPLING_PARAMS`` is the final fallback.

        Returns:
            The ``SamplingParams`` built via ``SamplingParams.from_optional``.

        Note:
            When ``response_format`` is set, ``self.structured_outputs`` is
            created (if missing) and updated in place to reflect it.
        """
        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        # With `echo`, prompt logprobs fall back to the requested top_logprobs.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        response_format = self.response_format
        if response_format is not None:
            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            # Set structured output params for response format
            if response_format.type == "json_object":
                self.structured_outputs.json_object = True
            elif response_format.type == "json_schema":
                json_schema = response_format.json_schema
                assert json_schema is not None
                self.structured_outputs.json = json_schema.json_schema
            elif response_format.type == "structural_tag":
                structural_tag = response_format
                assert structural_tag is not None and isinstance(
                    structural_tag,
                    (
                        LegacyStructuralTagResponseFormat,
                        StructuralTagResponseFormat,
                    ),
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

        # Copy instead of aliasing `self.vllm_xargs`: adding kv_transfer_params
        # below must not leak back into the request's own `vllm_xargs` dict.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            bad_words=self.bad_words,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )
835

836
    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject ``stream_options`` on requests that are not streaming."""
        options = data.get("stream_options")
        if not options:
            return data
        if data.get("stream"):
            return data
        raise ValueError("Stream options can only be defined when `stream=True`.")

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate ``prompt_logprobs`` and ``top_logprobs``.

        Both accept ``-1`` or any value ``>= 0``. Prompt logprobs cannot be
        requested together with ``stream=True``. A requested ``top_logprobs``
        (``-1`` or ``> 0``) additionally requires ``logprobs`` to be truthy.

        Raises:
            ValueError: on any violation of the rules above.
        """
        requested_prompt_lp = data.get("prompt_logprobs")
        if requested_prompt_lp is not None:
            is_streaming = data.get("stream")
            if is_streaming and (requested_prompt_lp > 0 or requested_prompt_lp == -1):
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`."
                )
            if requested_prompt_lp < 0 and requested_prompt_lp != -1:
                raise ValueError("`prompt_logprobs` must be a positive value or -1.")

        requested_top_lp = data.get("top_logprobs")
        if requested_top_lp is None:
            return data

        if requested_top_lp < 0 and requested_top_lp != -1:
            raise ValueError("`top_logprobs` must be a positive value or -1.")
        wants_top_lp = requested_top_lp == -1 or requested_top_lp > 0
        if wants_top_lp and not data.get("logprobs"):
            raise ValueError(
                "when using `top_logprobs`, `logprobs` must be set to true."
            )
        return data
865

866
867
    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Validate the structured-output constraints of a chat request.

        At most one of ``json``, ``regex`` and ``choice`` may be set. Any
        such constraint also cannot be combined with a named (forced)
        ``tool_choice``. ``None``, ``"none"``, ``"auto"`` and ``"required"``
        remain allowed.

        Raises:
            ValueError: if more than one constraint is set, or a constraint
                is combined with a named tool choice.
        """
        if isinstance(data, ValueError):
            raise data

        structured_outputs = data.get("structured_outputs", None)
        if structured_outputs is None:
            return data

        constraint_keys = ("json", "regex", "choice")
        # Accept both a raw dict (JSON body) and an already-built params
        # object (programmatic construction), which has no `.get` method.
        if isinstance(structured_outputs, dict):
            constraints = [structured_outputs.get(k) for k in constraint_keys]
        else:
            constraints = [
                getattr(structured_outputs, k, None) for k in constraint_keys
            ]
        count = sum(c is not None for c in constraints)

        # you can only use one kind of constraints for structured outputs
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )
        # you can only either use structured outputs or tools, not both.
        # This previously tested `count > 1`, which had already raised above,
        # so the check was unreachable; any single constraint conflicts.
        if count > 0 and data.get("tool_choice", "none") not in (
            None,
            "none",
            "auto",
            "required",
        ):
            raise ValueError(
                "You can only either use constraints for structured outputs "
                "or tools, not both."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        """Normalize and validate ``tool_choice`` against ``tools``.

        * If ``tools`` is non-empty and ``tool_choice`` is absent, it
          defaults to ``"auto"``.
        * ``tool_choice == "none"`` skips all further checks.
        * Any other non-null ``tool_choice`` requires ``tools`` and must be
          ``"auto"``, ``"required"`` or a named-function dict.
        * ``"required"`` with an empty ``tools`` list is downgraded to
          ``"none"`` and ``tools`` is removed, mirroring OpenAI.
        * A named-function dict must name one of the provided tools.

        Note: ``data`` is modified in place.

        Raises:
            ValueError: if any of the constraints above is violated.
        """

        def _declared_tool_name(tool):
            # Tools may be raw dicts (JSON body) or already-built objects.
            # Malformed entries yield None instead of raising KeyError /
            # TypeError, which pydantic would not report as a validation
            # error (it would surface as an internal server error).
            function = (
                tool.get("function")
                if isinstance(tool, dict)
                else getattr(tool, "function", None)
            )
            if isinstance(function, dict):
                return function.get("name")
            return getattr(function, "name", None)

        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        # if "tool_choice" is "none" -- no validation is needed for tools
        if "tool_choice" in data and data["tool_choice"] == "none":
            return data

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data and data["tool_choice"] is not None:
            # ensure that if "tool choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError("When using `tool_choice`, `tools` must be set.")

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto" or "required"
            if data["tool_choice"] not in ["auto", "required"] and not isinstance(
                data["tool_choice"], dict
            ):
                raise ValueError(
                    f"Invalid value for `tool_choice`: {data['tool_choice']}! "
                    'Only named tools, "none", "auto" or "required" '
                    "are supported."
                )

            # if tool_choice is "required" but the "tools" list is empty,
            # override the data to behave like "none" to align with
            # OpenAI's behavior.
            if (
                data["tool_choice"] == "required"
                and isinstance(data["tools"], list)
                and len(data["tools"]) == 0
            ):
                data["tool_choice"] = "none"
                del data["tools"]
                return data

            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            correct_usage_message = (
                'Correct usage: `{"type": "function",'
                ' "function": {"name": "my_function"}}`'
            )
            if isinstance(data["tool_choice"], dict):
                function = data["tool_choice"].get("function")
                if not isinstance(function, dict):
                    raise ValueError(
                        f"Invalid value for `function`: `{function}` in "
                        f"`tool_choice`! {correct_usage_message}"
                    )
                if "name" not in function:
                    raise ValueError(
                        f"Expected field `name` in `function` in "
                        f"`tool_choice`! {correct_usage_message}"
                    )
                function_name = function["name"]
                if not isinstance(function_name, str) or len(function_name) == 0:
                    raise ValueError(
                        f"Invalid `name` in `function`: `{function_name}`"
                        f" in `tool_choice`! {correct_usage_message}"
                    )
                if not any(
                    _declared_tool_name(tool) == function_name
                    for tool in data["tools"]
                ):
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match any"
                        " of the specified `tools`"
                    )
        return data

975
976
977
    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Forbid enabling both ``continue_final_message`` and
        ``add_generation_prompt`` at the same time."""
        if not data.get("continue_final_message"):
            return data
        if not data.get("add_generation_prompt"):
            return data
        raise ValueError(
            "Cannot set both `continue_final_message` and "
            "`add_generation_prompt` to True."
        )

985
986
987
    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Ensure ``cache_salt``, when present and not null, is a non-empty str."""
        salt = data.get("cache_salt")
        salt_is_valid = salt is None or (isinstance(salt, str) and bool(salt))
        if salt_is_valid:
            return data
        raise ValueError(
            "Parameter 'cache_salt' must be a non-empty string if provided."
        )

Zhuohan Li's avatar
Zhuohan Li committed
996

997
class CompletionRequest(OpenAIBaseModel):
    """Request body for the OpenAI-compatible text completions endpoint.

    Standard fields follow the OpenAI completions API ordering. They are
    followed by vLLM-specific sampling parameters and extra serving
    parameters.
    """

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str | None = None
    prompt: list[int] | list[list[int]] | str | list[str] | None = None
    echo: bool | None = False
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: int | None = None
    max_tokens: int | None = 16
    n: int = 1
    presence_penalty: float | None = 0.0
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    suffix: str | None = None
    temperature: float | None = None
    top_p: float | None = None
    user: str | None = None

    # --8<-- [start:completion-sampling-params]
    use_beam_search: bool = False
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
    length_penalty: float = 1.0
    stop_token_ids: list[int] | None = []
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    allowed_token_ids: list[int] | None = None
    prompt_logprobs: int | None = None
    # --8<-- [end:completion-sampling-params]

    # --8<-- [start:completion-extra-params]
    prompt_embeds: bytes | list[bytes] | None = None
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    response_format: AnyResponseFormat | None = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
            ", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
        ),
    )
    structured_outputs: StructuredOutputsParams | None = Field(
        default=None,
        description="Additional kwargs for structured outputs",
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    logits_processors: LogitsProcessors | None = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."
        ),
    )

    return_tokens_as_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."
        ),
    )
    return_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."
        ),
    )

    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )

    # --8<-- [end:completion-extra-params]

    # Default sampling parameters for completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self,
        max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> BeamSearchParams:
        """Build ``BeamSearchParams`` for this request.

        ``n`` is used as the beam width (1 if unset). ``temperature`` falls
        back to ``default_sampling_params["temperature"]``, then to 1.0.
        """
        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get("temperature", 1.0)

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Convert this completion request into engine ``SamplingParams``.

        Args:
            max_tokens: Maximum number of tokens to generate.
            logits_processor_pattern: Pattern forwarded to
                ``get_logits_processors`` with ``self.logits_processors``.
            default_sampling_params: Optional server-side defaults for
                ``repetition_penalty``, ``temperature``, ``top_p``, ``top_k``
                and ``min_p`` when the request leaves them as ``None``;
                ``_DEFAULT_SAMPLING_PARAMS`` is the final fallback.

        Returns:
            The ``SamplingParams`` built via ``SamplingParams.from_optional``.

        Note:
            When ``response_format`` is set, ``self.structured_outputs`` is
            created (if missing) and updated in place to reflect it.
        """
        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        # With `echo`, prompt logprobs fall back to the requested `logprobs`.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        # An echo-only request (max_tokens == 0) is still submitted with
        # max_tokens=1 below; presumably the caller drops that token
        # (not visible here).
        echo_without_generation = self.echo and self.max_tokens == 0

        response_format = self.response_format
        if response_format is not None:
            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            # Set structured output params for response format
            if response_format.type == "json_object":
                self.structured_outputs.json_object = True
            elif response_format.type == "json_schema":
                json_schema = response_format.json_schema
                assert json_schema is not None
                self.structured_outputs.json = json_schema.json_schema
            elif response_format.type == "structural_tag":
                structural_tag = response_format
                assert structural_tag is not None and isinstance(
                    structural_tag,
                    (
                        LegacyStructuralTagResponseFormat,
                        StructuralTagResponseFormat,
                    ),
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

        # Copy instead of aliasing `self.vllm_xargs`: adding kv_transfer_params
        # below must not leak back into the request's own `vllm_xargs` dict.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )

    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Reject requests setting more than one of json/regex/choice."""
        structured_outputs = data.get("structured_outputs", None)
        if structured_outputs is None:
            return data

        constraint_keys = ("json", "regex", "choice")
        # Accept both a raw dict (JSON body) and an already-built params
        # object (programmatic construction), which has no `.get` method.
        if isinstance(structured_outputs, dict):
            constraints = [structured_outputs.get(k) for k in constraint_keys]
        else:
            constraints = [
                getattr(structured_outputs, k, None) for k in constraint_keys
            ]
        count = sum(c is not None for c in constraints)
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate ``prompt_logprobs`` (``-1`` or ``>= 0``, not while
        streaming) and ``logprobs`` (``>= 0``)."""
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`."
                )

            if prompt_logprobs < 0 and prompt_logprobs != -1:
                raise ValueError("`prompt_logprobs` must be a positive value or -1.")
        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise ValueError("`logprobs` must be a positive value.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject ``stream_options`` on requests that are not streaming."""
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt_and_prompt_embeds(cls, data):
        """Require a non-empty ``prompt`` or a non-empty ``prompt_embeds``."""
        prompt = data.get("prompt")
        prompt_embeds = data.get("prompt_embeds")

        # An empty list prompt (no tokens / no strings) is as empty as "",
        # mirroring how an empty `prompt_embeds` list is treated below.
        prompt_is_empty = prompt is None or (
            isinstance(prompt, (str, list)) and len(prompt) == 0
        )
        embeds_is_empty = prompt_embeds is None or (
            isinstance(prompt_embeds, list) and len(prompt_embeds) == 0
        )

        if prompt_is_empty and embeds_is_empty:
            raise ValueError(
                "Either prompt or prompt_embeds must be provided and non-empty."
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Ensure ``cache_salt``, when given and not null, is a non-empty str."""
        if data.get("cache_salt") is not None and (
            not isinstance(data["cache_salt"], str) or not data["cache_salt"]
        ):
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )

        return data

Zhuohan Li's avatar
Zhuohan Li committed
1334

1335
class CompletionLogProbs(OpenAIBaseModel):
    """Legacy-completions logprobs payload: parallel per-token lists."""

    # Each list defaults to a fresh empty list per instance.
    text_offset: list[int] = Field(default_factory=list)
    token_logprobs: list[float | None] = Field(default_factory=list)
    tokens: list[str] = Field(default_factory=list)
    top_logprobs: list[dict[str, float] | None] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
1340
1341


1342
class CompletionResponseChoice(OpenAIBaseModel):
    """One choice in a non-streaming completions response."""

    index: int
    text: str
    logprobs: CompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    # Generated (response) token ids.
    token_ids: list[int] | None = None
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    # Token ids of the prompt.
    prompt_token_ids: list[int] | None = None
Zhuohan Li's avatar
Zhuohan Li committed
1358
1359


1360
class CompletionResponse(OpenAIBaseModel):
    """Non-streaming response of the completions endpoint."""

    # `id` and `created` are generated per instance when not supplied.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: Literal["text_completion"] = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    usage: UsageInfo

    # Extension beyond the OpenAI schema (vLLM only).
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None, description="KVTransfer parameters."
    )
Zhuohan Li's avatar
Zhuohan Li committed
1374
1375


1376
class CompletionResponseStreamChoice(OpenAIBaseModel):
    """One choice inside a streamed completions chunk."""

    index: int
    text: str
    logprobs: CompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    # Token-tracing fields, outside the OpenAI schema. The prompt ids live on
    # the choice so the layout matches CompletionResponseChoice.
    prompt_token_ids: list[int] | None = None
    token_ids: list[int] | None = None
Zhuohan Li's avatar
Zhuohan Li committed
1393
1394


1395
class CompletionStreamResponse(OpenAIBaseModel):
    """A single streamed chunk of a completions response."""

    # `id` and `created` are generated per instance when not supplied.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseStreamChoice]
    # Optional usage statistics; None unless set by the producer.
    usage: UsageInfo | None = Field(default=None)
1402
1403


1404
1405
1406
1407
1408
1409
class FunctionCall(OpenAIBaseModel):
    """A function invocation: its name and its arguments as a string."""

    name: str
    # Arguments are carried as a single string (not parsed here).
    arguments: str


class ToolCall(OpenAIBaseModel):
    """A complete tool call made by the model; only functions are supported."""

    # Generated via make_tool_call_id when not supplied.
    id: str = Field(default_factory=make_tool_call_id)
    type: Literal["function"] = "function"
    function: FunctionCall


1415
class DeltaFunctionCall(BaseModel):
    """Incremental piece of a function call; both parts may be absent."""

    name: str | None = None
    # Partial argument text for this delta.
    arguments: str | None = None
1418
1419
1420
1421


# a tool call delta: every field except `index` is optional
class DeltaToolCall(OpenAIBaseModel):
    """Incremental piece of a tool call in a streamed chat response.

    ``index`` is required; ``id``, ``type`` and ``function`` may be omitted.
    """

    id: str | None = None
    type: Literal["function"] | None = None
    index: int
    function: DeltaFunctionCall | None = None
1426
1427
1428
1429
1430
1431
1432


class ExtractedToolCallInformation(BaseModel):
    """Tool calls extracted from model output, plus any leftover content."""

    # Whether any tool was called.
    tools_called: bool

    # The tool calls that were extracted.
    tool_calls: list[ToolCall]

    # Remaining text content. Per the OpenAI spec, returning content together
    # with tool calls is rare, but some models do it on purpose.
    content: str | None = None
1438
1439


1440
class ChatMessage(OpenAIBaseModel):
    """A chat message as carried by a non-streaming chat completion choice."""

    role: str
    content: str | None = None
    refusal: str | None = None
    annotations: OpenAIAnnotation | None = None
    audio: OpenAIChatCompletionAudio | None = None
    function_call: FunctionCall | None = None
    tool_calls: list[ToolCall] = Field(default_factory=list)

    # vLLM-specific fields that are not in OpenAI spec
    reasoning: str | None = None
    reasoning_content: str | None = None
    """Deprecated: use `reasoning` instead."""

    @model_validator(mode="after")
    def handle_deprecated_reasoning_content(self):
        """Keep `reasoning` and the deprecated `reasoning_content` in sync.

        `reasoning` is authoritative and is mirrored into `reasoning_content`
        for backward compatibility. If only the deprecated
        `reasoning_content` was supplied, it is first promoted to
        `reasoning` instead of being overwritten with None.
        """
        if self.reasoning is None:
            self.reasoning = self.reasoning_content
        self.reasoning_content = self.reasoning
        return self
1459

1460

1461
1462
1463
class ChatCompletionLogProb(OpenAIBaseModel):
    """Log probability entry for a single token in a chat completion."""

    token: str
    # Defaults to -9999.0 when not provided.
    logprob: float = -9999.0
    bytes: list[int] | None = None
1465
1466
1467


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    """Per-token logprob entry together with its top alternatives."""

    # Workaround: this subclass declares its own `field_names` cache so it is
    # not shared with the parent class.
    field_names: ClassVar[set[str] | None] = None
    top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
1472
1473
1474


class ChatCompletionLogProbs(OpenAIBaseModel):
    """Logprobs container of a chat completion choice."""

    # One entry per token, or None when absent.
    content: list[ChatCompletionLogProbsContent] | None = None
1476
1477


1478
class ChatCompletionResponseChoice(OpenAIBaseModel):
    """One choice of a non-streaming chat completion response."""

    index: int
    message: ChatMessage
    logprobs: ChatCompletionLogProbs | None = None
    # Defaults to "stop", the default per the OpenAI spec.
    finish_reason: str | None = "stop"
    # Kept for legacy reasons; not part of the OpenAI spec.
    stop_reason: int | str | None = None
    # Extension outside the OpenAI spec: generated token ids, useful for
    # tracing tokens in agent scenarios.
    token_ids: list[int] | None = None
1489
1490


1491
class ChatCompletionResponse(OpenAIBaseModel):
    # Top-level body of a non-streaming chat completion response.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    # Unix time in whole seconds, taken when the object is constructed.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    prompt_token_ids: list[int] | None = None
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None, description="KVTransfer parameters."
    )


class DeltaMessage(OpenAIBaseModel):
    # Incremental message fragment carried by a streaming chunk; every field
    # has an empty default so a chunk can carry any subset of them.
    role: str | None = None
    content: str | None = None
    reasoning: str | None = None
    reasoning_content: str | None = None
    """Deprecated: use `reasoning` instead."""
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)

    @model_validator(mode="after")
    def handle_deprecated_reasoning_content(self):
        """Keep `reasoning` and the deprecated `reasoning_content` in sync.

        `reasoning` is the source of truth and is mirrored into
        `reasoning_content` for backward compatibility. When only the
        deprecated field was supplied, its value is adopted as `reasoning`
        first instead of being overwritten with None.

        Returns:
            The same model instance (required by pydantic "after" validators).
        """
        if self.reasoning is None and self.reasoning_content is not None:
            self.reasoning = self.reasoning_content
        self.reasoning_content = self.reasoning
        return self


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streaming chat completion chunk.
    index: int
    delta: DeltaMessage
    logprobs: ChatCompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None
    # not part of the OpenAI spec but for tracing the tokens
    token_ids: list[int] | None = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    # A single chunk of a streaming chat completion.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Unix time in whole seconds, taken when the object is constructed.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
    # not part of the OpenAI spec but for tracing the tokens
    prompt_token_ids: list[int] | None = None


class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streaming transcription chunk.
    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None


class TranscriptionStreamResponse(OpenAIBaseModel):
    # A single chunk of a streaming transcription response.
    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    object: Literal["transcription.chunk"] = "transcription.chunk"
    # Unix time in whole seconds, taken when the object is constructed.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)


class InputTokensDetails(OpenAIBaseModel):
    # Breakdown of input-token usage; the *_per_turn lists hold one count
    # per turn and default to empty.
    cached_tokens: int
    input_tokens_per_turn: list[int] = Field(default_factory=list)
    cached_tokens_per_turn: list[int] = Field(default_factory=list)


class OutputTokensDetails(OpenAIBaseModel):
    # Breakdown of output-token usage; all counters default to 0 and the
    # *_per_turn lists default to empty.
    reasoning_tokens: int = 0
    tool_output_tokens: int = 0
    output_tokens_per_turn: list[int] = Field(default_factory=list)
    tool_output_tokens_per_turn: list[int] = Field(default_factory=list)


class ResponseUsage(OpenAIBaseModel):
    # Token usage reported on a Responses API response.
    input_tokens: int
    input_tokens_details: InputTokensDetails
    output_tokens: int
    output_tokens_details: OutputTokensDetails
    total_tokens: int


def serialize_message(msg):
    """Convert a single message into a JSON-compatible value.

    Args:
        msg: A plain dict, an object exposing ``to_dict()``, or a pydantic
            model.

    Returns:
        The dict unchanged, the result of ``msg.to_dict()``, or
        ``msg.model_dump(mode="json")`` for pydantic models.
    """
    if isinstance(msg, dict):
        return msg
    if hasattr(msg, "to_dict"):
        return msg.to_dict()
    # Fallback: pydantic model. Use model_dump (not model_dump_json) so the
    # result is a mapping like the other branches; a pre-encoded JSON string
    # would be JSON-encoded a second time by the enclosing serializer.
    return msg.model_dump(mode="json")


def serialize_messages(msgs):
    """Serialize every message in ``msgs`` with ``serialize_message``.

    Returns:
        A list of serialized messages, or None when ``msgs`` is None or empty.
    """
    return [serialize_message(msg) for msg in msgs] if msgs else None


class ResponsesResponse(OpenAIBaseModel):
    # Response object of the Responses API.
    id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
    # Unix time in whole seconds, taken when the object is constructed.
    created_at: int = Field(default_factory=lambda: int(time.time()))
    # error: Optional[ResponseError] = None
    incomplete_details: IncompleteDetails | None = None
    instructions: str | None = None
    metadata: Metadata | None = None
    model: str
    object: Literal["response"] = "response"
    output: list[ResponseOutputItem]
    parallel_tool_calls: bool
    temperature: float
    tool_choice: ToolChoice
    tools: list[Tool]
    top_p: float
    background: bool
    max_output_tokens: int
    max_tool_calls: int | None = None
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"]
    status: ResponseStatus
    text: ResponseTextConfig | None = None
    top_logprobs: int | None = None
    truncation: Literal["auto", "disabled"]
    usage: ResponseUsage | None = None
    user: str | None = None

    # --8<-- [start:responses-extra-params]
    # These are populated when enable_response_messages is set to True
    # NOTE: custom serialization is needed
    # see serialize_input_messages and serialize_output_messages
    input_messages: list[ChatCompletionMessageParam] | None = None
    output_messages: list[ChatCompletionMessageParam] | None = None
    # --8<-- [end:responses-extra-params]

    # NOTE: openAI harmony doesn't serialize TextContent properly,
    # TODO: this fixes for TextContent, but need to verify for tools etc
    # https://github.com/openai/harmony/issues/78
    @field_serializer("output_messages", when_used="json")
    def serialize_output_messages(self, msgs, _info):
        # JSON-mode only: route each message through serialize_messages.
        return serialize_messages(msgs)

    # NOTE: openAI harmony doesn't serialize TextContent properly, this fixes it
    # https://github.com/openai/harmony/issues/78
    @field_serializer("input_messages", when_used="json")
    def serialize_input_messages(self, msgs, _info):
        # JSON-mode only: route each message through serialize_messages.
        return serialize_messages(msgs)

    @classmethod
    def from_request(
        cls,
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        model_name: str,
        created_time: int,
        output: list[ResponseOutputItem],
        status: ResponseStatus,
        usage: ResponseUsage | None = None,
        input_messages: list[ChatCompletionMessageParam] | None = None,
        output_messages: list[ChatCompletionMessageParam] | None = None,
    ) -> "ResponsesResponse":
        """Build a response from the originating request and its results.

        Request-level settings (instructions, tools, truncation, ...) are
        copied from ``request``; temperature, top_p, max_output_tokens and
        top_logprobs come from the effective ``sampling_params``.

        Args:
            request: The Responses API request being answered.
            sampling_params: Sampling parameters actually used.
            model_name: Model name to report.
            created_time: Creation timestamp to report as ``created_at``.
            output: Generated output items.
            status: Final status; ``"incomplete"`` is reported with
                reason ``max_output_tokens``.
            usage: Optional token usage.
            input_messages: Optional raw input messages to echo back.
            output_messages: Optional raw output messages to echo back.

        Returns:
            A populated ``ResponsesResponse``.
        """
        incomplete_details: IncompleteDetails | None = None
        if status == "incomplete":
            incomplete_details = IncompleteDetails(reason="max_output_tokens")
        # TODO: implement the other reason for incomplete_details,
        # which is content_filter
        # incomplete_details = IncompleteDetails(reason='content_filter')

        return cls(
            id=request.request_id,
            created_at=created_time,
            incomplete_details=incomplete_details,
            instructions=request.instructions,
            metadata=request.metadata,
            model=model_name,
            output=output,
            input_messages=input_messages,
            output_messages=output_messages,
            parallel_tool_calls=request.parallel_tool_calls,
            temperature=sampling_params.temperature,
            tool_choice=request.tool_choice,
            tools=request.tools,
            top_p=sampling_params.top_p,
            background=request.background,
            max_output_tokens=sampling_params.max_tokens,
            max_tool_calls=request.max_tool_calls,
            previous_response_id=request.previous_response_id,
            prompt=request.prompt,
            reasoning=request.reasoning,
            service_tier=request.service_tier,
            status=status,
            text=request.text,
            top_logprobs=sampling_params.logprobs,
            truncation=request.truncation,
            user=request.user,
            usage=usage,
        )


# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
# Streaming event of type `response.reasoning_part.done`, identifying a
# reasoning content part that is done and carrying that part.
class ResponseReasoningPartDoneEvent(OpenAIBaseModel):
    content_index: int
    """The index of the content part that is done."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that is done."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.done"]
    """The type of the event. Always `response.reasoning_part.done`."""


# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
# Streaming event of type `response.reasoning_part.added`, identifying a
# reasoning content part that was added and carrying that part.
class ResponseReasoningPartAddedEvent(OpenAIBaseModel):
    content_index: int
    """The index of the content part that was added."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that was added."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.added"]
    """The type of the event. Always `response.reasoning_part.added`."""


# vLLM Streaming Events
# Note: we override the response type with the vLLM ResponsesResponse type
# (the `type: ignore[override]` silences the type checker's complaint about
# narrowing the inherited field's type).
class ResponseCompletedEvent(OpenAIResponseCompletedEvent):
    response: ResponsesResponse  # type: ignore[override]


# OpenAI's "response created" event with `response` typed as the vLLM
# ResponsesResponse model.
class ResponseCreatedEvent(OpenAIResponseCreatedEvent):
    response: ResponsesResponse  # type: ignore[override]


# OpenAI's "response in progress" event with `response` typed as the vLLM
# ResponsesResponse model.
class ResponseInProgressEvent(OpenAIResponseInProgressEvent):
    response: ResponsesResponse  # type: ignore[override]


# Union of the event models that make up a streamed Responses API output.
StreamingResponsesResponse: TypeAlias = (
    ResponseCreatedEvent
    | ResponseInProgressEvent
    | ResponseCompletedEvent
    | ResponseOutputItemAddedEvent
    | ResponseOutputItemDoneEvent
    | ResponseContentPartAddedEvent
    | ResponseContentPartDoneEvent
    | ResponseReasoningTextDeltaEvent
    | ResponseReasoningTextDoneEvent
    | ResponseReasoningPartAddedEvent
    | ResponseReasoningPartDoneEvent
    | ResponseCodeInterpreterCallInProgressEvent
    | ResponseCodeInterpreterCallCodeDeltaEvent
    | ResponseWebSearchCallInProgressEvent
    | ResponseWebSearchCallSearchingEvent
    | ResponseWebSearchCallCompletedEvent
    | ResponseCodeInterpreterCallCodeDoneEvent
    | ResponseCodeInterpreterCallInterpretingEvent
    | ResponseCodeInterpreterCallCompletedEvent
)


class TokenizeCompletionRequest(OpenAIBaseModel):
    # Tokenize request for a raw text prompt.
    model: str | None = None
    prompt: str

    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )


class TokenizeChatRequest(OpenAIBaseModel):
    # Tokenize request for a list of chat messages rendered through the
    # chat template.
    model: str | None = None
    messages: list[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    tools: list[ChatCompletionToolsParam] | None = Field(
        default=None,
        description=("A list of tools the model may call."),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests that set both template-continuation flags to True.

        Raises:
            ValueError: if both `continue_final_message` and
                `add_generation_prompt` are truthy in the raw input.
        """
        # NOTE(review): this inspects only explicitly supplied values. Because
        # `add_generation_prompt` defaults to True, sending
        # `continue_final_message=True` alone passes here and leaves both
        # flags True on the validated model.
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data


# A tokenize request is either a raw-prompt request or a chat-messages request.
TokenizeRequest: TypeAlias = TokenizeCompletionRequest | TokenizeChatRequest


class TokenizeResponse(OpenAIBaseModel):
    # Result of a tokenize request; `token_strs` is optional (None default).
    count: int
    max_model_len: int
    tokens: list[int]
    token_strs: list[str] | None = None


class DetokenizeRequest(OpenAIBaseModel):
    # Request to convert token ids back into text.
    model: str | None = None
    tokens: list[int]


class DetokenizeResponse(OpenAIBaseModel):
    # Text produced by a detokenize request.
    prompt: str


class TokenizerInfoResponse(OpenAIBaseModel):
    """
    Response containing tokenizer configuration
    equivalent to tokenizer_config.json
    """

    # extra="allow" lets keys beyond `tokenizer_class` be set on the model.
    model_config = ConfigDict(extra="allow")
    tokenizer_class: str


class LoadLoRAAdapterRequest(BaseModel):
    # Request to load a LoRA adapter under `lora_name` from `lora_path`.
    lora_name: str
    lora_path: str


class UnloadLoRAAdapterRequest(BaseModel):
    # Request to unload a LoRA adapter; `lora_int_id` is optional.
    lora_name: str
    lora_int_id: int | None = Field(default=None)


## Protocols for Audio
# Output formats accepted by the transcription and translation endpoints.
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]


class TranscriptionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    language: str | None = None
    """The language of the input audio.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    ## TODO (varun) : Support if set to 0, certain thresholds are met !!

    timestamp_granularities: list[Literal["word", "segment"]] = Field(
        alias="timestamp_granularities[]", default=[]
    )
    """The timestamp granularities to populate for this transcription.

    `response_format` must be set `verbose_json` to use timestamp granularities.
    Either or both of these options are supported: `word`, or `segment`. Note:
    There is no additional latency for segment timestamps, but generating word
    timestamps incurs additional latency.
    """

    stream: bool | None = False
    """When set, it will enable output to be streamed in a similar fashion
    as the Chat Completion endpoint.
    """
    # --8<-- [start:transcription-extra-params]
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )
    # --8<-- [end:transcription-extra-params]

    to_language: str | None = None
    """The language of the output audio we transcribe to.

    Please note that this is not currently used by supported models at this
    time, but it is a placeholder for future use, matching translation api.
    """

    # --8<-- [start:transcription-sampling-params]
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    top_p: float | None = None
    """Enables nucleus (top-p) sampling, where tokens are selected from the
    smallest possible set whose cumulative probability exceeds `p`.
    """

    top_k: int | None = None
    """Limits sampling to the `k` most probable tokens at each step."""

    min_p: float | None = None
    """Filters out tokens with a probability lower than `min_p`, ensuring a
    minimum likelihood threshold during sampling.
    """

    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    frequency_penalty: float | None = 0.0
    """The frequency penalty to use for sampling."""

    repetition_penalty: float | None = None
    """The repetition penalty to use for sampling."""

    presence_penalty: float | None = 0.0
    """The presence penalty to use for sampling."""
    # --8<-- [end:transcription-sampling-params]

    # Default sampling parameters for transcription requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Translate this request into `SamplingParams`.

        For temperature, top_p, top_k, min_p and repetition_penalty, a value
        set on the request wins; otherwise `default_sampling_params` is
        consulted, and finally `_DEFAULT_SAMPLING_PARAMS`.

        Args:
            default_max_tokens: Used as `max_tokens` (this request has no
                max-tokens field of its own).
            default_sampling_params: Optional defaults keyed by
                sampling-parameter name.

        Returns:
            Sampling parameters with DELTA output when `stream` is truthy,
            FINAL_ONLY otherwise.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        def resolve(name: str, value):
            # Precedence: request value, then supplied default, then the
            # class-level fallback.
            if value is not None:
                return value
            return default_sampling_params.get(
                name, self._DEFAULT_SAMPLING_PARAMS[name]
            )

        # NOTE(review): `temperature` is typed `float` with default 0.0, so
        # validation never leaves it None; the fallback (and the 1.0 in
        # _DEFAULT_SAMPLING_PARAMS) only applies if None is assigned directly.
        temperature = resolve("temperature", self.temperature)
        top_p = resolve("top_p", self.top_p)
        top_k = resolve("top_k", self.top_k)
        min_p = resolve("min_p", self.min_p)
        repetition_penalty = resolve("repetition_penalty", self.repetition_penalty)

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            presence_penalty=self.presence_penalty,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            extra_args=self.vllm_xargs,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_transcription_request(cls, data):
        """Validate raw request data before field parsing.

        Raises:
            HTTPException: 422 when `file` was sent as a plain string.
            ValueError: when a stream option is set without `stream=True`.
        """
        if isinstance(data.get("file"), str):
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail="Expected 'file' to be a file-like object, not 'str'.",
            )

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data


# Transcription response objects
class TranscriptionUsageAudio(OpenAIBaseModel):
    # Duration-based usage: `seconds` of processed audio.
    type: Literal["duration"] = "duration"
    seconds: int


class TranscriptionResponse(OpenAIBaseModel):
    # Plain (non-verbose) transcription result.
    text: str
    """The transcribed text."""
    usage: TranscriptionUsageAudio


# A single transcribed word with its start/end time in seconds.
class TranscriptionWord(OpenAIBaseModel):
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranscriptionSegment(OpenAIBaseModel):
    # One segment of a verbose transcription with timing and quality metrics.
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float | None = None
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float | None = None
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    # NOTE(review): the docstring below says "higher than 1.0", which a
    # probability can never exceed; verify the intended threshold.
    no_speech_prob: float | None = None
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranscriptionResponseVerbose(OpenAIBaseModel):
    # Verbose transcription result with optional segment/word details.
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: list[TranscriptionSegment] | None = None
    """Segments of the transcribed text and their corresponding details."""

    words: list[TranscriptionWord] | None = None
    """Extracted words and their corresponding timestamps."""


# Either the plain or the verbose transcription response model.
TranscriptionResponseVariant: TypeAlias = (
    TranscriptionResponse | TranscriptionResponseVerbose
)


class TranslationResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streaming translation chunk.
    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None


class TranslationStreamResponse(OpenAIBaseModel):
    # A single chunk of a streaming translation response.
    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
    object: Literal["translation.chunk"] = "translation.chunk"
    # Unix time in whole seconds, taken when the object is constructed.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranslationResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)


class TranslationRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranslation

    file: UploadFile
    """
    The audio file object (not file name) to translate, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    # TODO support additional sampling parameters
    # --8<-- [start:translation-sampling-params]
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """
    # --8<-- [end:translation-sampling-params]

    # --8<-- [start:translation-extra-params]
    language: str | None = None
    """The language of the input audio we translate from.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy.
    """

    to_language: str | None = None
    """The language of the input audio we translate to.

    Please note that this is not supported by all models, refer to the specific
    model documentation for more details.
    For instance, Whisper only supports `to_language=en`.
    """

    stream: bool | None = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
    Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False
    # --8<-- [end:translation-extra-params]

    # Default sampling parameters for translation requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Translate this request into `SamplingParams`.

        Args:
            default_max_tokens: Used as `max_tokens` (this request has no
                max-tokens field of its own).
            default_sampling_params: Optional defaults keyed by
                sampling-parameter name; only "temperature" is consulted.

        Returns:
            Sampling parameters with DELTA output when `stream` is truthy,
            FINAL_ONLY otherwise.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        # Default parameters
        # NOTE(review): `temperature` is typed `float` with default 0.0, so
        # validation never leaves it None; this fallback only applies if None
        # is assigned directly.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject stream options on non-streaming requests.

        Raises:
            ValueError: when a stream option is set without `stream=True`.
        """
        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data


# Translation response objects
class TranslationResponse(OpenAIBaseModel):
    # Plain translation response carrying only the translated text.
    text: str
    """The translated text."""


class TranslationWord(OpenAIBaseModel):
    # A single word together with its start/end time (seconds) in the audio.
    # Field order is kept as declared; it determines serialization order.
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


# One segment of a verbose translation response: its time span, decoded text
# and tokens, plus decoding-quality statistics used to judge the segment.
class TranslationSegment(OpenAIBaseModel):
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float | None = None
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float | None = None
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    # NOTE(review): a probability cannot exceed 1.0, so the "higher than 1.0"
    # threshold in the docstring below (which appears to mirror OpenAI's API
    # reference wording) can never be met — confirm the intended threshold.
    no_speech_prob: float | None = None
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


# Verbose translation response: the translated text plus audio metadata and
# optional per-segment / per-word timing details.
class TranslationResponseVerbose(OpenAIBaseModel):
    # NOTE(review): OpenAI's verbose translation object appears to report
    # `duration` as a number; confirm `str` is intentional before changing it,
    # since changing the type alters the wire format.
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The translated text."""

    segments: list[TranslationSegment] | None = None
    """Segments of the translated text and their corresponding details."""

    words: list[TranslationWord] | None = None
    """Extracted words and their corresponding timestamps."""


# Either translation response shape: plain text only, or verbose with details.
TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose


####### Tokens IN <> Tokens OUT #######
# Request schema for the token-in/token-out generation API: the caller supplies
# prompt token ids and engine `SamplingParams` directly instead of text.
class GenerateRequest(BaseModel):
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    token_ids: list[int]
    """The token ids to generate text from."""

    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None
    """The processed MM inputs for the model."""

    sampling_params: SamplingParams
    """The sampling parameters for the model."""

    model: str | None = None

    stream: bool | None = False
    stream_options: StreamOptions | None = None
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )


class GenerateResponseChoice(BaseModel):
    # One generated choice inside a GenerateResponse's `choices` list.
    # presumably the position of this choice among the returned choices — confirm
    index: int
    logprobs: ChatCompletionLogProbs | None = None
    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"
    # presumably the generated output token ids for this choice — confirm
    token_ids: list[int] | None = None


# Response schema for the token-in/token-out generation API.
class GenerateResponse(BaseModel):
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    choices: list[GenerateResponseChoice]

    # presumably one entry per prompt position (None where no logprobs are
    # available), each mapping token id -> Logprob — confirm with the producer
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )