protocol.py 86.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
6
import json
Zhuohan Li's avatar
Zhuohan Li committed
7
import time
8
from http import HTTPStatus
9
from typing import Annotated, Any, ClassVar, Literal, TypeAlias
Zhuohan Li's avatar
Zhuohan Li committed
10

11
import regex as re
12
import torch
13
from fastapi import HTTPException, UploadFile
14
from openai.types.chat.chat_completion_audio import (
15
16
17
    ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
18
19
20
21
22
from openai.types.responses import (
    ResponseCodeInterpreterCallCodeDeltaEvent,
    ResponseCodeInterpreterCallCodeDoneEvent,
    ResponseCodeInterpreterCallCompletedEvent,
    ResponseCodeInterpreterCallInProgressEvent,
23
24
25
26
27
    ResponseCodeInterpreterCallInterpretingEvent,
    ResponseContentPartAddedEvent,
    ResponseContentPartDoneEvent,
    ResponseFunctionToolCall,
    ResponseInputItemParam,
28
29
30
31
    ResponseMcpCallArgumentsDeltaEvent,
    ResponseMcpCallArgumentsDoneEvent,
    ResponseMcpCallCompletedEvent,
    ResponseMcpCallInProgressEvent,
32
33
34
35
36
37
38
39
40
41
42
    ResponseOutputItem,
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponsePrompt,
    ResponseReasoningTextDeltaEvent,
    ResponseReasoningTextDoneEvent,
    ResponseStatus,
    ResponseWebSearchCallCompletedEvent,
    ResponseWebSearchCallInProgressEvent,
    ResponseWebSearchCallSearchingEvent,
)
43
from openai.types.responses import (
44
45
46
    ResponseCompletedEvent as OpenAIResponseCompletedEvent,
)
from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent
47
from openai.types.responses import (
48
49
    ResponseInProgressEvent as OpenAIResponseInProgressEvent,
)
50
from openai.types.responses.response_reasoning_item import (
51
52
    Content as ResponseReasoningTextContent,
)
53
from openai_harmony import Message as OpenAIHarmonyMessage
54
55
56
57
58

# Backward compatibility for OpenAI client versions
try:  # For older openai versions (< 1.100.0)
    from openai.types.responses import ResponseTextConfig
except ImportError:  # For newer openai versions (>= 1.100.0)
59
    from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig
60

61

62
from openai.types.responses.response import IncompleteDetails, ToolChoice
63
64
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
65
66
67
68
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
69
    ValidationError,
70
    field_serializer,
71
72
    model_validator,
)
Zhuohan Li's avatar
Zhuohan Li committed
73

74
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
75
from vllm.logger import init_logger
76
from vllm.logprobs import Logprob
77
78
79
80
81
82
from vllm.sampling_params import (
    BeamSearchParams,
    RequestOutputKind,
    SamplingParams,
    StructuredOutputsParams,
)
83
84
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
85

86
87
# Module-level logger, named after this module.
logger = init_logger(__name__)

88
# Integer limits of torch.long; used in this file to bound the `seed` field
# of ChatCompletionRequest.
_LONG_INFO = torch.iinfo(torch.long)
89

Zhuohan Li's avatar
Zhuohan Li committed
90

91
class OpenAIBaseModel(BaseModel):
    # Common base for the protocol models in this file: accepts unknown
    # fields and logs a warning naming them.

    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Per-class cache of declared field names and aliases (filled lazily).
    field_names: ClassVar[set[str] | None] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Run normal validation, then warn about keys that are not fields.

        Args:
            data: Raw input handed to the model; only inspected when it is a
                dict.
            handler: The wrapped validation callable.

        Returns:
            Whatever ``handler(data)`` returns.
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Look the cache up in this class's own namespace only. A plain
        # attribute lookup would fall through to a parent class that has
        # already cached *its* field names, and a subclass's additional
        # fields would then be reported as unknown.
        field_names = cls.__dict__.get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        unknown_keys = data.keys() - field_names
        if unknown_keys:
            logger.warning(
                "The following fields were present in the request but ignored: %s",
                unknown_keys,
            )
        return result
121
122


123
class ErrorInfo(OpenAIBaseModel):
    # Payload of a single error: message, type, optional parameter name and
    # an integer code.
    message: str
    type: str
    param: str | None = None
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
128
129


130
131
132
133
# Envelope carrying a single ErrorInfo under the `error` key.
class ErrorResponse(OpenAIBaseModel):
    error: ErrorInfo


134
class ModelPermission(OpenAIBaseModel):
    # Permission record attached to a ModelCard.
    # `id` defaults to "modelperm-<random uuid>" and `created` to the current
    # Unix time (whole seconds), both evaluated per instance.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: str | None = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
147
148


149
class ModelCard(OpenAIBaseModel):
    # Description of one model entry (see ModelList).
    # `created` defaults to the current Unix time (whole seconds), evaluated
    # per instance; `permission` defaults to a fresh empty list.
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: str | None = None
    parent: str | None = None
    max_model_len: int | None = None
    permission: list[ModelPermission] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
158
159


160
class ModelList(OpenAIBaseModel):
    # List wrapper around ModelCard entries; `data` defaults to a fresh
    # empty list per instance.
    object: str = "list"
    data: list[ModelCard] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
163
164


165
class PromptTokenUsageInfo(OpenAIBaseModel):
    # Optional breakdown of prompt tokens (used by
    # UsageInfo.prompt_tokens_details).
    cached_tokens: int | None = None
167
168


169
class UsageInfo(OpenAIBaseModel):
    # Token accounting for a request; all counters default to 0 and the
    # prompt-token breakdown is optional.
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: int | None = 0
    prompt_tokens_details: PromptTokenUsageInfo | None = None
Zhuohan Li's avatar
Zhuohan Li committed
174
175


176
177
class RequestResponseMetadata(BaseModel):
    # Per-request bookkeeping: the request id plus an optional final
    # UsageInfo. Derives from plain BaseModel, not OpenAIBaseModel.
    request_id: str
    final_usage_info: UsageInfo | None = None
179
180


181
182
class JsonSchemaResponseFormat(OpenAIBaseModel):
    # Named JSON schema used by ResponseFormat.json_schema.
    name: str
    description: str | None = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    strict: bool | None = None
188
189


190
class LegacyStructuralTag(OpenAIBaseModel):
    # One entry of LegacyStructuralTagResponseFormat.structures: a `begin`
    # string, an optional schema, and an `end` string.
    begin: str
    # schema is the field, but that causes conflicts with pydantic so
    # instead use structural_tag_schema with an alias
    structural_tag_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    end: str


198
class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
    # Legacy "structural_tag" response format: a list of LegacyStructuralTag
    # structures plus a list of trigger strings.
    type: Literal["structural_tag"]
    structures: list[LegacyStructuralTag]
    triggers: list[str]


204
205
206
207
208
209
210
211
212
213
# "structural_tag" response format whose `format` payload is left untyped
# (Any) at this layer.
class StructuralTagResponseFormat(OpenAIBaseModel):
    type: Literal["structural_tag"]
    format: Any


# Either flavour of the "structural_tag" response format (legacy or current).
AnyStructuralTagResponseFormat: TypeAlias = (
    LegacyStructuralTagResponseFormat | StructuralTagResponseFormat
)


214
class ResponseFormat(OpenAIBaseModel):
    # Non-structural-tag response format selector.
    # type must be "json_schema", "json_object", or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: JsonSchemaResponseFormat | None = None
218
219


220
221
222
# Every response-format shape accepted by the `response_format` field of
# ChatCompletionRequest.
AnyResponseFormat: TypeAlias = (
    ResponseFormat | StructuralTagResponseFormat | LegacyStructuralTagResponseFormat
)
223
224


225
class StreamOptions(OpenAIBaseModel):
    # Streaming options: `include_usage` defaults to True and
    # `continuous_usage_stats` to False.
    include_usage: bool | None = True
    continuous_usage_stats: bool | None = False
228
229


230
231
class FunctionDefinition(OpenAIBaseModel):
    # Declaration of a callable function: required name, optional
    # description and an optional free-form `parameters` dict.
    name: str
    description: str | None = None
    parameters: dict[str, Any] | None = None
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249


# A tool entry; the only accepted `type` is "function".
class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


# Reference to a function by name (used by ChatCompletionNamedToolChoiceParam).
class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


# Tool-choice value that names one specific function; the only accepted
# `type` is "function".
class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


250
251
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
class LogitsProcessorConstructor(BaseModel):
    # Qualified name of a logits-processor class/factory together with the
    # optional positional and keyword arguments to call it with.
    qualname: str
    args: list[Any] | None = None
    kwargs: dict[str, Any] | None = None

    model_config = ConfigDict(extra="forbid")

259

260
# Each entry is either a qualified name (str) or a LogitsProcessorConstructor.
LogitsProcessors = list[str | LogitsProcessorConstructor]
261
262


263
def get_logits_processors(
    processors: LogitsProcessors | None, pattern: str | None
) -> list[Any] | None:
    """Resolve the requested logits processors into objects.

    Args:
        processors: Qualified names, or constructor specs carrying a
            qualified name plus optional ``args`` / ``kwargs``.
        pattern: Regular expression each qualified name must match (tested
            with ``re.match``). When empty or ``None``, no processor is
            accepted.

    Returns:
        ``None`` when ``processors`` is empty or ``None``; otherwise a list
        with one resolved object per entry. Constructor specs are called
        with their arguments and the call result is stored.

    Raises:
        ValueError: If processors were requested but no pattern is
            configured, if a qualified name does not match the pattern, or
            if a qualified name cannot be resolved.
    """
    if not processors:
        return None

    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information."
        )

    logits_processors = []
    for processor in processors:
        qualname = processor if isinstance(processor, str) else processor.qualname
        if not re.match(pattern, qualname):
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information."
            )
        try:
            logits_processor = resolve_obj_by_qualname(qualname)
        except Exception as e:
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {e}"
            ) from e
        if isinstance(processor, LogitsProcessorConstructor):
            # A constructor spec: call the resolved object with its arguments.
            logits_processor = logits_processor(
                *processor.args or [], **processor.kwargs or {}
            )
        logits_processors.append(logits_processor)
    return logits_processors


297
# Element type of ResponsesRequest.input when it is given as a list.
ResponseInputOutputItem: TypeAlias = ResponseInputItemParam | ResponseOutputItem
298
299


300
301
302
class ResponsesRequest(OpenAIBaseModel):
    # Request body for the Responses API, plus server-specific extras.

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/responses/create
    background: bool | None = False
    include: (
        list[
            Literal[
                "code_interpreter_call.outputs",
                "computer_call_output.output.image_url",
                "file_search_call.results",
                "message.input_image.image_url",
                "message.output_text.logprobs",
                "reasoning.encrypted_content",
            ],
        ]
        | None
    ) = None
    input: str | list[ResponseInputOutputItem]
    instructions: str | None = None
    max_output_tokens: int | None = None
    max_tool_calls: int | None = None
    metadata: Metadata | None = None
    model: str | None = None
    parallel_tool_calls: bool | None = True
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
    store: bool | None = True
    stream: bool | None = False
    temperature: float | None = None
    text: ResponseTextConfig | None = None
    tool_choice: ToolChoice = "auto"
    tools: list[Tool] = Field(default_factory=list)
    top_logprobs: int | None = 0
    top_p: float | None = None
    truncation: Literal["auto", "disabled"] | None = "disabled"
    user: str | None = None

    # --8<-- [start:responses-extra-params]
    request_id: str = Field(
        default_factory=lambda: f"resp_{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    enable_response_messages: bool = Field(
        default=False,
        description=(
            "Dictates whether or not to return messages as part of the "
            # Trailing space added: the adjacent literals previously
            # concatenated to "supported fornon-background".
            "response object. Currently only supported for "
            "non-background and gpt-oss only. "
        ),
    )
    # similar to input_messages / output_messages in ResponsesResponse
    # we take in previous_input_messages (ie in harmony format)
    # this cannot be used in conjunction with previous_response_id
    # TODO: consider supporting non harmony messages as well
    previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
    # --8<-- [end:responses-extra-params]

    # Fallbacks used by to_sampling_params when neither the request nor the
    # server-provided defaults specify a value.
    _DEFAULT_SAMPLING_PARAMS = {
        "temperature": 1.0,
        "top_p": 1.0,
    }

    def to_sampling_params(
        self,
        default_max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Translate this request into `SamplingParams`.

        Args:
            default_max_tokens: Upper bound on generated tokens; also used
                as-is when `max_output_tokens` is not set.
            default_sampling_params: Optional server-side defaults. The keys
                "temperature", "top_p" and "stop_token_ids" are consulted.

        Returns:
            The `SamplingParams` built via `SamplingParams.from_optional`.

        Raises:
            NotImplementedError: If `text.format.type` is "json_object".
        """
        if self.max_output_tokens is None:
            max_tokens = default_max_tokens
        else:
            # Never exceed the caller-provided bound.
            max_tokens = min(self.max_output_tokens, default_max_tokens)

        default_sampling_params = default_sampling_params or {}
        # Precedence: request value, then server default, then class fallback.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        stop_token_ids = default_sampling_params.get("stop_token_ids")

        # Structured output
        structured_outputs = None
        if self.text is not None and self.text.format is not None:
            response_format = self.text.format
            if (
                response_format.type == "json_schema"
                and response_format.schema_ is not None
            ):
                structured_outputs = StructuredOutputsParams(
                    json=response_format.schema_
                )
            elif response_format.type == "json_object":
                raise NotImplementedError("json_object is not supported")

        # TODO: add more parameters
        return SamplingParams.from_optional(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
            stop_token_ids=stop_token_ids,
            output_kind=(
                RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
            ),
            structured_outputs=structured_outputs,
        )

    def is_include_output_logprobs(self) -> bool:
        """Check if the request includes output logprobs."""
        # A None `include` is not a list, so it yields False here as well.
        return (
            isinstance(self.include, list)
            and "message.output_text.logprobs" in self.include
        )

    @model_validator(mode="before")
    @classmethod
    def validate_background(cls, data):
        """Reject `background` requests whose `store` value is falsy."""
        if not data.get("background"):
            return data
        if not data.get("store", True):
            raise ValueError("background can only be used when `store` is true")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt(cls, data):
        """Reject requests that supply a `prompt` template."""
        if data.get("prompt") is not None:
            raise ValueError("prompt template is not supported")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Require `cache_salt`, when provided, to be a non-empty string."""
        if data.get("cache_salt") is not None and (
            not isinstance(data["cache_salt"], str) or not data["cache_salt"]
        ):
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def function_call_parsing(cls, data):
        """Parse function_call dictionaries into ResponseFunctionToolCall objects.
        This ensures Pydantic can properly resolve union types in the input field.
        Function calls provided as dicts are converted to ResponseFunctionToolCall
        objects before validation, while invalid structures are left for Pydantic
        to reject with appropriate error messages.
        """

        input_data = data.get("input")

        # Early return for None, strings, or bytes
        # (strings are iterable but shouldn't be processed)
        if input_data is None or isinstance(input_data, (str, bytes)):
            return data

        # Convert iterators (like ValidatorIterator) to list
        if not isinstance(input_data, list):
            try:
                input_data = list(input_data)
            except TypeError:
                # Not iterable, leave as-is for Pydantic to handle
                return data

        processed_input = []
        for item in input_data:
            if isinstance(item, dict) and item.get("type") == "function_call":
                try:
                    processed_input.append(ResponseFunctionToolCall(**item))
                except ValidationError:
                    # Let Pydantic handle validation for malformed function calls
                    logger.debug(
                        "Failed to parse function_call to ResponseFunctionToolCall, "
                        "leaving for Pydantic validation"
                    )
                    processed_input.append(item)
            else:
                processed_input.append(item)

        data["input"] = processed_input
        return data

515

516
class ChatCompletionRequest(OpenAIBaseModel):
517
518
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
519
    messages: list[ChatCompletionMessageParam]
520
521
522
523
524
525
    model: str | None = None
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: bool | None = False
    top_logprobs: int | None = 0
    max_tokens: int | None = Field(
526
        default=None,
527
528
        deprecated="max_tokens is deprecated in favor of "
        "the max_completion_tokens field",
529
    )
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
    max_completion_tokens: int | None = None
    n: int | None = 1
    presence_penalty: float | None = 0.0
    response_format: AnyResponseFormat | None = None
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    temperature: float | None = None
    top_p: float | None = None
    tools: list[ChatCompletionToolsParam] | None = None
    tool_choice: (
        Literal["none"]
        | Literal["auto"]
        | Literal["required"]
        | ChatCompletionNamedToolChoiceParam
        | None
    ) = "none"
    reasoning_effort: Literal["low", "medium", "high"] | None = None
549
    include_reasoning: bool = True
550
    parallel_tool_calls: bool | None = True
551

552
    # NOTE this will be ignored by vLLM
553
    user: str | None = None
554

555
    # --8<-- [start:chat-completion-sampling-params]
556
    use_beam_search: bool = False
557
558
559
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
560
    length_penalty: float = 1.0
561
    stop_token_ids: list[int] | None = []
562
563
564
565
566
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
567
568
569
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    prompt_logprobs: int | None = None
    allowed_token_ids: list[int] | None = None
570
    bad_words: list[str] = Field(default_factory=list)
571
    # --8<-- [end:chat-completion-sampling-params]
572

573
    # --8<-- [start:chat-completion-extra-params]
574
    echo: bool = Field(
575
576
577
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
578
579
            "if they belong to the same role."
        ),
580
    )
581
    add_generation_prompt: bool = Field(
582
        default=True,
583
584
585
586
587
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
588
    )
589
590
    continue_final_message: bool = Field(
        default=False,
591
592
593
594
595
596
597
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
598
    )
599
    add_special_tokens: bool = Field(
600
601
602
603
604
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
605
            "special tokens so this should be set to false (as is the "
606
607
            "default)."
        ),
608
    )
609
    documents: list[dict[str, str]] | None = Field(
610
        default=None,
611
612
613
614
615
616
617
        description=(
            "A list of dicts representing documents that will be accessible to "
            "the model if it is performing RAG (retrieval-augmented generation)."
            " If the template does not support RAG, this argument will have no "
            "effect. We recommend that each document should be a dict containing "
            '"title" and "text" keys.'
        ),
618
    )
619
    chat_template: str | None = Field(
620
621
622
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
623
624
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
625
626
            "does not define one."
        ),
627
    )
628
    chat_template_kwargs: dict[str, Any] | None = Field(
629
        default=None,
630
631
        description=(
            "Additional keyword args to pass to the template renderer. "
632
633
            "Will be accessible by the chat template."
        ),
634
    )
635
    mm_processor_kwargs: dict[str, Any] | None = Field(
636
637
638
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
639
    structured_outputs: StructuredOutputsParams | None = Field(
640
        default=None,
641
        description="Additional kwargs for structured outputs",
642
    )
643
644
645
646
647
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
648
649
            "if the served model does not use priority scheduling."
        ),
650
    )
651
    request_id: str = Field(
652
        default_factory=random_uuid,
653
654
655
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
656
657
            "through out the inference process and return in response."
        ),
658
    )
659
    logits_processors: LogitsProcessors | None = Field(
660
661
662
663
664
665
666
667
668
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
669
670
671
            "{'param': 'value'}}."
        ),
    )
672
    return_tokens_as_token_ids: bool | None = Field(
673
674
675
676
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
677
678
679
            "that are not JSON-encodable can be identified."
        ),
    )
680
    return_token_ids: bool | None = Field(
681
682
683
684
685
686
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
687
688
689
            "need to map generated text back to input tokens."
        ),
    )
690
    cache_salt: str | None = Field(
691
692
693
694
695
696
697
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
698
            "to 256 bit)."
699
700
        ),
    )
701
    kv_transfer_params: dict[str, Any] | None = Field(
Robert Shaw's avatar
Robert Shaw committed
702
        default=None,
703
704
        description="KVTransfer parameters used for disaggregated serving.",
    )
705

706
    vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
707
        default=None,
708
        description=(
709
            "Additional request parameters with (list of) string or "
710
711
            "numeric values, used by custom extensions."
        ),
712
713
    )

714
    # --8<-- [end:chat-completion-extra-params]
Zhuohan Li's avatar
Zhuohan Li committed
715

716
717
718
719
720
    # Default sampling parameters for chat completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
721
        "top_k": 0,
722
723
724
725
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self, max_tokens: int, default_sampling_params: dict
    ) -> BeamSearchParams:
        """Build the `BeamSearchParams` for this request.

        Args:
            max_tokens: Forwarded unchanged as `max_tokens`.
            default_sampling_params: Server-side defaults; only the
                "temperature" key is consulted. A `None` value is treated
                as an empty mapping.

        Returns:
            `BeamSearchParams` whose beam width is `n` (1 when `n` is None).
        """
        beam_width = 1 if self.n is None else self.n

        temperature = self.temperature
        if temperature is None:
            # Precedence: server default, then the class-level fallback.
            server_defaults = default_sampling_params or {}
            temperature = server_defaults.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return BeamSearchParams(
            beam_width=beam_width,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )
742

743
    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict,
    ) -> SamplingParams:
        """Translate this chat request into `SamplingParams`.

        Args:
            max_tokens: Forwarded unchanged as `max_tokens`.
            logits_processor_pattern: Passed to `get_logits_processors`
                together with the request's `logits_processors`.
            default_sampling_params: Fallback values consulted for any of
                `repetition_penalty`, `temperature`, `top_p`, `top_k` and
                `min_p` that the request left as `None`.

        Returns:
            The object produced by `SamplingParams.from_optional(...)`.

        Note:
            When `response_format` is set, `self.structured_outputs` is
            created if missing and updated in place.
        """

        def _resolve(name: str):
            # Precedence: request value, then the supplied defaults, then
            # the class-level table.
            value = getattr(self, name)
            if value is None:
                value = default_sampling_params.get(
                    name, self._DEFAULT_SAMPLING_PARAMS[name]
                )
            return value

        repetition_penalty = _resolve("repetition_penalty")
        temperature = _resolve("temperature")
        top_p = _resolve("top_p")
        top_k = _resolve("top_k")
        min_p = _resolve("min_p")

        # With `echo`, prompt logprobs fall back to `top_logprobs`.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        response_format = self.response_format
        if response_format is not None:
            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            # Set structured output params for response format
            if response_format.type == "json_object":
                self.structured_outputs.json_object = True
            elif response_format.type == "json_schema":
                json_schema = response_format.json_schema
                assert json_schema is not None
                self.structured_outputs.json = json_schema.json_schema
            elif response_format.type == "structural_tag":
                structural_tag = response_format
                assert structural_tag is not None and isinstance(
                    structural_tag,
                    (
                        LegacyStructuralTagResponseFormat,
                        StructuralTagResponseFormat,
                    ),
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

        # Copy so that adding `kv_transfer_params` below does not mutate the
        # request's own `vllm_xargs` dict.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params

        return SamplingParams.from_optional(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            # `logprobs` is the on/off switch; `top_logprobs` carries the count.
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            bad_words=self.bad_words,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )
839

840
    @model_validator(mode="before")
841
    @classmethod
842
843
    def validate_stream_options(cls, data):
        """Reject `stream_options` on a request that is not streaming.

        Returns `data` unchanged when the combination is acceptable.

        Raises:
            ValueError: if `stream_options` is truthy while `stream` is not.
        """
        is_streaming = data.get("stream")
        if not is_streaming and data.get("stream_options"):
            raise ValueError("Stream options can only be defined when `stream=True`.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `top_logprobs` on the raw request.

        `-1` and any value >= 0 pass the range check for both fields. A
        `prompt_logprobs` of `-1` or > 0 is rejected together with `stream`,
        and a `top_logprobs` of `-1` or > 0 requires `logprobs` to be truthy.

        Raises:
            ValueError: on any violation of the rules above.
        """
        requested_prompt = data.get("prompt_logprobs")
        if requested_prompt is not None:
            if data.get("stream") and (
                requested_prompt > 0 or requested_prompt == -1
            ):
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`."
                )
            if requested_prompt < 0 and requested_prompt != -1:
                raise ValueError("`prompt_logprobs` must be a positive value or -1.")

        requested_top = data.get("top_logprobs")
        if requested_top is not None:
            if requested_top < 0 and requested_top != -1:
                raise ValueError("`top_logprobs` must be a positive value or -1.")
            if (requested_top == -1 or requested_top > 0) and not data.get(
                "logprobs"
            ):
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )

        return data
869

870
871
    @model_validator(mode="before")
    @classmethod
872
    def check_structured_outputs_count(cls, data):
        """Ensure at most one of `json`, `regex`, `choice` is requested.

        `structured_outputs` may be a mapping (raw JSON input) or an object
        exposing the constraints as attributes; both are inspected.

        Returns `data` unchanged when valid.

        Raises:
            ValueError: if more than one constraint kind is set. A
                `ValueError` instance passed in as `data` is re-raised.
        """
        if isinstance(data, ValueError):
            raise data

        structured_outputs = data.get("structured_outputs", None)
        if structured_outputs is None:
            return data

        def _constraint(name):
            # Mapping-like inputs keep using `.get`; anything else is read by
            # attribute, so an already-built params object does not crash here.
            getter = getattr(structured_outputs, "get", None)
            if callable(getter):
                return getter(name)
            return getattr(structured_outputs, name, None)

        count = sum(_constraint(k) is not None for k in ("json", "regex", "choice"))
        # you can only use one kind of constraints for structured outputs
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )

        # NOTE(review): a second check used to follow here, rejecting a named
        # `tool_choice` combined with structured outputs. It was guarded by the
        # same `count > 1` condition and therefore unreachable, so it has been
        # removed. Decide separately whether that restriction should be
        # enforced; enabling it would reject requests accepted today.
        return data

    @model_validator(mode="before")
    @classmethod
904
905
906
    def check_tool_usage(cls, data):
        """Validate and normalise `tool_choice` against `tools` on raw input.

        Rules, applied in order:
          * `tools` given without `tool_choice`: `tool_choice` becomes "auto".
          * `tool_choice` absent, `None` or "none": nothing else is checked.
          * Any other `tool_choice` requires `tools` to be set.
          * `tool_choice` must be "auto", "required" or a dict.
          * "required" with an empty `tools` list is rewritten to "none" and
            `tools` is removed from `data`.
          * A dict `tool_choice` must carry `function.name` (a non-empty
            string) matching the name of one of the supplied tools.

        Malformed entries in `tools` count as non-matching rather than raising
        `KeyError`/`TypeError`, so the caller always gets a `ValueError`.

        Raises:
            ValueError: if any rule above is violated.
        """
        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        tool_choice = data.get("tool_choice")
        # Absent / None / "none": no validation is needed for tools.
        if tool_choice is None or tool_choice == "none":
            return data

        # ensure that if "tool choice" is specified, tools are present
        tools = data.get("tools")
        if tools is None:
            raise ValueError("When using `tool_choice`, `tools` must be set.")

        # make sure that tool choice is either a named tool
        # OR that it's set to "auto" or "required"
        if tool_choice not in ["auto", "required"] and not isinstance(
            tool_choice, dict
        ):
            raise ValueError(
                f"Invalid value for `tool_choice`: {tool_choice}! "
                'Only named tools, "none", "auto" or "required" '
                "are supported."
            )

        # if tool_choice is "required" but the "tools" list is empty,
        # override the data to behave like "none" to align with
        # OpenAI's behavior.
        if tool_choice == "required" and isinstance(tools, list) and len(tools) == 0:
            data["tool_choice"] = "none"
            del data["tools"]
            return data

        if not isinstance(tool_choice, dict):
            return data

        # ensure that if "tool_choice" is specified as an object,
        # it matches a valid tool
        correct_usage_message = (
            'Correct usage: `{"type": "function",'
            ' "function": {"name": "my_function"}}`'
        )
        function = tool_choice.get("function")
        if not isinstance(function, dict):
            raise ValueError(
                f"Invalid value for `function`: `{function}` in "
                f"`tool_choice`! {correct_usage_message}"
            )
        if "name" not in function:
            raise ValueError(
                f"Expected field `name` in `function` in "
                f"`tool_choice`! {correct_usage_message}"
            )
        function_name = function["name"]
        if not isinstance(function_name, str) or len(function_name) == 0:
            raise ValueError(
                f"Invalid `name` in `function`: `{function_name}`"
                f" in `tool_choice`! {correct_usage_message}"
            )

        def _tool_name(tool):
            # Accept dict-shaped tools as well as objects exposing
            # `.function.name`; anything malformed yields None.
            if isinstance(tool, dict):
                fn = tool.get("function")
            else:
                fn = getattr(tool, "function", None)
            if isinstance(fn, dict):
                return fn.get("name")
            return getattr(fn, "name", None)

        candidates = tools if isinstance(tools, (list, tuple)) else ()
        if not any(_tool_name(tool) == function_name for tool in candidates):
            raise ValueError(
                "The tool specified in `tool_choice` does not match any"
                " of the specified `tools`"
            )
        return data

979
980
981
    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests enabling both prompt-continuation switches.

        Raises:
            ValueError: if `continue_final_message` and
                `add_generation_prompt` are both truthy.
        """
        continues_final = data.get("continue_final_message")
        if continues_final and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

989
990
991
    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Require `cache_salt`, when given, to be a non-empty string.

        Raises:
            ValueError: if `cache_salt` is present, not `None`, and either
                not a `str` or empty.
        """
        salt = data.get("cache_salt")
        if salt is None:
            return data
        if isinstance(salt, str) and salt:
            return data
        raise ValueError(
            "Parameter 'cache_salt' must be a non-empty string if provided."
        )

Zhuohan Li's avatar
Zhuohan Li committed
1000

1001
class CompletionRequest(OpenAIBaseModel):
    # Request body model for text completions. The first group of fields
    # follows the OpenAI API; the two snippet-marked groups below are extra
    # sampling parameters and extra request parameters. No class docstring is
    # used on purpose: it would become the schema description.

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str | None = None
    prompt: list[int] | list[list[int]] | str | list[str] | None = None
    echo: bool | None = False
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: int | None = None
    max_tokens: int | None = 16
    n: int = 1
    presence_penalty: float | None = 0.0
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    suffix: str | None = None
    temperature: float | None = None
    top_p: float | None = None
    user: str | None = None

    # --8<-- [start:completion-sampling-params]
    use_beam_search: bool = False
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
    length_penalty: float = 1.0
    stop_token_ids: list[int] | None = []
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    allowed_token_ids: list[int] | None = None
    prompt_logprobs: int | None = None
    # --8<-- [end:completion-sampling-params]

    # --8<-- [start:completion-extra-params]
    prompt_embeds: bytes | list[bytes] | None = None
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    response_format: AnyResponseFormat | None = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
            ", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
        ),
    )
    structured_outputs: StructuredOutputsParams | None = Field(
        default=None,
        description="Additional kwargs for structured outputs",
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    logits_processors: LogitsProcessors | None = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."
        ),
    )

    return_tokens_as_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."
        ),
    )
    return_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."
        ),
    )

    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )

    # --8<-- [end:completion-extra-params]

    # Default sampling parameters for completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self,
        max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> BeamSearchParams:
        """Build `BeamSearchParams` for this request.

        Args:
            max_tokens: Forwarded unchanged as `max_tokens`.
            default_sampling_params: Optional fallback values; only
                `temperature` is consulted here.

        Returns:
            `BeamSearchParams` with `beam_width` taken from `n`.
        """
        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        if (temperature := self.temperature) is None:
            # Same fallback table as `to_sampling_params`.
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Translate this completion request into `SamplingParams`.

        Args:
            max_tokens: Forwarded as `max_tokens`, except that a request with
                `echo` set and `max_tokens == 0` passes 1 instead.
            logits_processor_pattern: Passed to `get_logits_processors`
                together with the request's `logits_processors`.
            default_sampling_params: Optional fallback values consulted for
                any of `repetition_penalty`, `temperature`, `top_p`, `top_k`
                and `min_p` that the request left as `None`.

        Returns:
            The object produced by `SamplingParams.from_optional(...)`.

        Note:
            When `response_format` is set, `self.structured_outputs` is
            created if missing and updated in place.
        """
        if default_sampling_params is None:
            default_sampling_params = {}

        def _resolve(name: str):
            # Precedence: request value, then the supplied defaults, then
            # the class-level table.
            value = getattr(self, name)
            if value is None:
                value = default_sampling_params.get(
                    name, self._DEFAULT_SAMPLING_PARAMS[name]
                )
            return value

        repetition_penalty = _resolve("repetition_penalty")
        temperature = _resolve("temperature")
        top_p = _resolve("top_p")
        top_k = _resolve("top_k")
        min_p = _resolve("min_p")

        # With `echo`, prompt logprobs fall back to `logprobs`.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        echo_without_generation = self.echo and self.max_tokens == 0

        response_format = self.response_format
        if response_format is not None:
            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            # Set structured output params for response format
            if response_format.type == "json_object":
                self.structured_outputs.json_object = True
            elif response_format.type == "json_schema":
                json_schema = response_format.json_schema
                assert json_schema is not None
                self.structured_outputs.json = json_schema.json_schema
            elif response_format.type == "structural_tag":
                structural_tag = response_format
                assert structural_tag is not None and isinstance(
                    structural_tag,
                    (
                        LegacyStructuralTagResponseFormat,
                        StructuralTagResponseFormat,
                    ),
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

        # Copy so that adding `kv_transfer_params` below does not mutate the
        # request's own `vllm_xargs` dict.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params

        return SamplingParams.from_optional(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )

    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Ensure at most one of `json`, `regex`, `choice` is requested.

        `structured_outputs` may be a mapping (raw JSON input) or an object
        exposing the constraints as attributes; both are inspected.

        Raises:
            ValueError: if more than one constraint kind is set.
        """
        structured_outputs = data.get("structured_outputs", None)
        if structured_outputs is None:
            return data

        def _constraint(name):
            # Mapping-like inputs keep using `.get`; anything else is read by
            # attribute, so an already-built params object does not crash here.
            getter = getattr(structured_outputs, "get", None)
            if callable(getter):
                return getter(name)
            return getattr(structured_outputs, name, None)

        count = sum(_constraint(k) is not None for k in ("json", "regex", "choice"))
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `logprobs` on the raw request.

        Raises:
            ValueError: if `prompt_logprobs` is `-1` or > 0 while `stream` is
                truthy, if `prompt_logprobs` is negative other than `-1`, or
                if `logprobs` is negative.
        """
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`."
                )

            if prompt_logprobs < 0 and prompt_logprobs != -1:
                raise ValueError("`prompt_logprobs` must be a positive value or -1.")
        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise ValueError("`logprobs` must be a positive value.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` on a request that is not streaming.

        Raises:
            ValueError: if `stream_options` is truthy while `stream` is not.
        """
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt_and_prompt_embeds(cls, data):
        """Require a usable `prompt` or `prompt_embeds`.

        `prompt` counts as empty when it is `None` or `""`; `prompt_embeds`
        counts as empty when it is `None` or an empty list.

        Raises:
            ValueError: if both are empty.
        """
        prompt = data.get("prompt")
        prompt_embeds = data.get("prompt_embeds")

        prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "")
        embeds_is_empty = prompt_embeds is None or (
            isinstance(prompt_embeds, list) and len(prompt_embeds) == 0
        )

        if prompt_is_empty and embeds_is_empty:
            raise ValueError(
                "Either prompt or prompt_embeds must be provided and non-empty."
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Require `cache_salt`, when given, to be a non-empty string.

        Raises:
            ValueError: if `cache_salt` is not `None` and is either not a
                `str` or empty.
        """
        if data.get("cache_salt") is not None and (
            not isinstance(data["cache_salt"], str) or not data["cache_salt"]
        ):
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data

Zhuohan Li's avatar
Zhuohan Li committed
1338

1339
class CompletionLogProbs(OpenAIBaseModel):
    # Four list-valued fields; each instance starts with its own empty list.
    text_offset: list[int] = Field(default_factory=list)
    # Entries may be None.
    token_logprobs: list[float | None] = Field(default_factory=list)
    tokens: list[str] = Field(default_factory=list)
    # Each entry is a str -> float mapping, or None.
    top_logprobs: list[dict[str, float] | None] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
1344
1345


1346
class CompletionResponseChoice(OpenAIBaseModel):
    # One choice of a non-streaming completion response.
    index: int
    text: str
    logprobs: CompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    # Token ids of the response.
    token_ids: list[int] | None = None
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    # Token ids of the prompt.
    prompt_token_ids: list[int] | None = None
Zhuohan Li's avatar
Zhuohan Li committed
1362
1363


1364
class CompletionResponse(OpenAIBaseModel):
    # Non-streaming completion response. `id` defaults to "cmpl-" plus a
    # random uuid and `created` to the current time in whole seconds.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: Literal["text_completion"] = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters.",
    )
Zhuohan Li's avatar
Zhuohan Li committed
1378
1379


1380
class CompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed completion chunk.
    index: int
    text: str
    logprobs: CompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    # not part of the OpenAI spec but for tracing the tokens
    # prompt tokens is put into choice to align with CompletionResponseChoice
    prompt_token_ids: list[int] | None = None
    token_ids: list[int] | None = None
Zhuohan Li's avatar
Zhuohan Li committed
1397
1398


1399
class CompletionStreamResponse(OpenAIBaseModel):
    # One streamed completion chunk. `id` defaults to "cmpl-" plus a random
    # uuid and `created` to the current time in whole seconds.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    # Plain `str` here, not a `Literal`.
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
1406
1407


1408
1409
1410
1411
1412
1413
class FunctionCall(OpenAIBaseModel):
    # Both fields are required strings.
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    # `id` is generated by `make_tool_call_id` when the caller gives none.
    id: str = Field(default_factory=make_tool_call_id)
    # "function" is the only accepted value.
    type: Literal["function"] = "function"
    function: FunctionCall


1419
class DeltaFunctionCall(BaseModel):
    # Same fields as `FunctionCall`, but both optional.
    name: str | None = None
    arguments: str | None = None
1422
1423
1424
1425


# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
    # `index` is the only required field.
    id: str | None = None
    type: Literal["function"] | None = None
    index: int
    function: DeltaFunctionCall | None = None
1430
1431
1432
1433
1434
1435
1436


class ExtractedToolCallInformation(BaseModel):
    # Whether any tools were called.
    tools_called: bool

    # The tool calls that were extracted.
    tool_calls: list[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: str | None = None
1442
1443


1444
class ChatMessage(OpenAIBaseModel):
    # Message of a non-streaming chat completion choice. Only `role` is
    # required.
    role: str
    content: str | None = None
    refusal: str | None = None
    annotations: OpenAIAnnotation | None = None
    audio: OpenAIChatCompletionAudio | None = None
    function_call: FunctionCall | None = None
    tool_calls: list[ToolCall] = Field(default_factory=list)

    # vLLM-specific fields that are not in OpenAI spec
    reasoning: str | None = None
    reasoning_content: str | None = None
    """Deprecated: use `reasoning` instead."""

    @model_validator(mode="after")
    def handle_deprecated_reasoning_content(self):
        """Keep `reasoning` and the deprecated `reasoning_content` in sync.

        `reasoning` is authoritative. If only `reasoning_content` was
        supplied, it is promoted to `reasoning` first, so the value is not
        silently discarded; `reasoning_content` then mirrors `reasoning`.
        """
        if self.reasoning is None and self.reasoning_content is not None:
            self.reasoning = self.reasoning_content
        self.reasoning_content = self.reasoning
        return self
1463

1464

1465
1466
1467
class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    # Defaults to -9999.0 when no value is supplied.
    logprob: float = -9999.0
    bytes: list[int] | None = None
1469
1470
1471


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    # Workaround: redefine fields name cache so that it's not
    # shared with the super class.
    field_names: ClassVar[set[str] | None] = None
    # Adds a list of `ChatCompletionLogProb` entries to the inherited fields.
    top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
1476
1477
1478


class ChatCompletionLogProbs(OpenAIBaseModel):
    # Optional list of per-entry logprob records.
    content: list[ChatCompletionLogProbsContent] | None = None
1480
1481


1482
class ChatCompletionResponseChoice(OpenAIBaseModel):
    # One choice of a non-streaming chat completion response.
    index: int
    message: ChatMessage
    logprobs: ChatCompletionLogProbs | None = None
    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: int | str | None = None
    # not part of the OpenAI spec but is useful for tracing the tokens
    # in agent scenarios
    token_ids: list[int] | None = None
1493
1494


1495
class ChatCompletionResponse(OpenAIBaseModel):
    # Non-streaming chat completion payload. `id` and `created` are generated
    # at construction time when the caller does not supply them.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    prompt_token_ids: list[int] | None = None
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None, description="KVTransfer parameters."
    )
1511
1512


1513
class DeltaMessage(OpenAIBaseModel):
    # Incremental message fragment used by the streaming choice models;
    # every field is optional.
    role: str | None = None
    content: str | None = None
    reasoning: str | None = None
    reasoning_content: str | None = None
    """Deprecated: use `reasoning` instead."""
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)

    @model_validator(mode="after")
    def handle_deprecated_reasoning_content(self):
        """Copy reasoning to reasoning_content for backward compatibility."""
        # The copy is unconditional: any value passed only through
        # `reasoning_content` is replaced by the value of `reasoning`.
        self.reasoning_content = self.reasoning
        return self

1527

1528
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    # One entry of `ChatCompletionStreamResponse.choices`. Unlike the
    # non-streaming choice, `finish_reason` defaults to None here.
    index: int
    delta: DeltaMessage
    logprobs: ChatCompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None
    # not part of the OpenAI spec but for tracing the tokens
    token_ids: list[int] | None = None
1536
1537


1538
class ChatCompletionStreamResponse(OpenAIBaseModel):
    # One streamed chat completion chunk. `id` and `created` are generated at
    # construction time when the caller does not supply them.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
    # not part of the OpenAI spec but for tracing the tokens
    prompt_token_ids: list[int] | None = None
1547
1548


1549
1550
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    # One entry of `TranscriptionStreamResponse.choices`.
    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None
1553
1554
1555
1556
1557
1558
1559
1560


class TranscriptionStreamResponse(OpenAIBaseModel):
    # One streamed transcription chunk. `id` ("trsc-" prefix) and `created`
    # are generated at construction time when not supplied.
    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    object: Literal["transcription.chunk"] = "transcription.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
1562
1563


1564
1565
class InputTokensDetails(OpenAIBaseModel):
    # `cached_tokens` is required; the per-turn lists default to empty.
    cached_tokens: int
    input_tokens_per_turn: list[int] = Field(default_factory=list)
    cached_tokens_per_turn: list[int] = Field(default_factory=list)
1568
1569
1570


class OutputTokensDetails(OpenAIBaseModel):
    # Scalar counters default to 0; the per-turn lists default to empty.
    reasoning_tokens: int = 0
    tool_output_tokens: int = 0
    output_tokens_per_turn: list[int] = Field(default_factory=list)
    tool_output_tokens_per_turn: list[int] = Field(default_factory=list)
1575
1576
1577
1578
1579
1580
1581
1582


class ResponseUsage(OpenAIBaseModel):
    # Usage schema: input, output and total token counts, with the input and
    # output counts each paired with a details model. All fields are required.
    input_tokens: int
    input_tokens_details: InputTokensDetails
    output_tokens: int
    output_tokens_details: OutputTokensDetails
    total_tokens: int
1583
1584


1585
1586
1587
1588
1589
1590
def serialize_message(msg):
    """
    Serializes a single message

    Dicts are returned unchanged, objects exposing `to_dict()` are converted
    through it, and anything else goes through `model_dump_json()`.
    Note that the last branch yields a JSON *string*, not a dict.
    """
    if isinstance(msg, dict):
        return msg
    elif hasattr(msg, "to_dict"):
        return msg.to_dict()
    else:
        # fallback to pydantic dump
        return msg.model_dump_json()


def serialize_messages(msgs):
    """
    Serializes multiple messages

    Returns None when `msgs` is falsy (None or empty); otherwise a list with
    each element passed through `serialize_message`.
    """
    if not msgs:
        return None
    return [serialize_message(item) for item in msgs]


1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
class ResponseRawMessageAndToken(OpenAIBaseModel):
    """Class to show the raw message.
    If message / tokens diverge, tokens is the source of truth"""

    # Raw text and integer token list for the same message.
    message: str
    tokens: list[int]
    # Fixed literal tag; defaults to its only allowed value.
    type: Literal["raw_message_tokens"] = "raw_message_tokens"


# Either a list of chat message params or a list of raw message/token records
# (the two element types are not mixed within one list by this annotation).
ResponseInputOutputMessage: TypeAlias = (
    list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken]
)


1619
1620
1621
1622
class ResponsesResponse(OpenAIBaseModel):
    # Response object for the Responses API. Use `from_request` to build one
    # from a request plus the sampling parameters that were applied.
    id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
    created_at: int = Field(default_factory=lambda: int(time.time()))
    # error: Optional[ResponseError] = None
    incomplete_details: IncompleteDetails | None = None
    instructions: str | None = None
    metadata: Metadata | None = None
    model: str
    object: Literal["response"] = "response"
    output: list[ResponseOutputItem]
    parallel_tool_calls: bool
    temperature: float
    tool_choice: ToolChoice
    tools: list[Tool]
    top_p: float
    background: bool
    max_output_tokens: int
    max_tool_calls: int | None = None
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"]
    status: ResponseStatus
    text: ResponseTextConfig | None = None
    top_logprobs: int | None = None
    truncation: Literal["auto", "disabled"]
    usage: ResponseUsage | None = None
    user: str | None = None

    # --8<-- [start:responses-extra-params]
    # These are populated when enable_response_messages is set to True
    # NOTE: custom serialization is needed
    # see serialize_input_messages and serialize_output_messages
    input_messages: ResponseInputOutputMessage | None = None
    output_messages: ResponseInputOutputMessage | None = None
    # --8<-- [end:responses-extra-params]

    # NOTE: openAI harmony doesn't serialize TextContent properly,
    # TODO: this fixes for TextContent, but need to verify for tools etc
    # https://github.com/openai/harmony/issues/78
    @field_serializer("output_messages", when_used="json")
    def serialize_output_messages(self, msgs, _info):
        return serialize_messages(msgs)

    # NOTE: openAI harmony doesn't serialize TextContent properly, this fixes it
    # https://github.com/openai/harmony/issues/78
    @field_serializer("input_messages", when_used="json")
    def serialize_input_messages(self, msgs, _info):
        return serialize_messages(msgs)

    @classmethod
    def from_request(
        cls,
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        model_name: str,
        created_time: int,
        output: list[ResponseOutputItem],
        status: ResponseStatus,
        usage: ResponseUsage | None = None,
        input_messages: ResponseInputOutputMessage | None = None,
        output_messages: ResponseInputOutputMessage | None = None,
    ) -> "ResponsesResponse":
        """Build a response from the originating request.

        `temperature`, `top_p`, `max_output_tokens` and `top_logprobs` are
        taken from `sampling_params`; the remaining echoed settings come from
        `request`. The response `id` is the request's `request_id`.
        """
        incomplete_details: IncompleteDetails | None = None
        if status == "incomplete":
            # Currently the only reason reported for an incomplete response.
            incomplete_details = IncompleteDetails(reason="max_output_tokens")
        # TODO: implement the other reason for incomplete_details,
        # which is content_filter
        # incomplete_details = IncompleteDetails(reason='content_filter')
        return cls(
            id=request.request_id,
            created_at=created_time,
            incomplete_details=incomplete_details,
            instructions=request.instructions,
            metadata=request.metadata,
            model=model_name,
            output=output,
            input_messages=input_messages,
            output_messages=output_messages,
            parallel_tool_calls=request.parallel_tool_calls,
            temperature=sampling_params.temperature,
            tool_choice=request.tool_choice,
            tools=request.tools,
            top_p=sampling_params.top_p,
            background=request.background,
            max_output_tokens=sampling_params.max_tokens,
            max_tool_calls=request.max_tool_calls,
            previous_response_id=request.previous_response_id,
            prompt=request.prompt,
            reasoning=request.reasoning,
            service_tier=request.service_tier,
            status=status,
            text=request.text,
            top_logprobs=sampling_params.logprobs,
            truncation=request.truncation,
            user=request.user,
            usage=usage,
        )


1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
class ResponseReasoningPartDoneEvent(OpenAIBaseModel):
    # All fields are required; `type` has no default, so the literal must be
    # passed explicitly when constructing the event.
    content_index: int
    """The index of the content part that is done."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that is done."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.done"]
    """The type of the event. Always `response.reasoning_part.done`."""


# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
class ResponseReasoningPartAddedEvent(OpenAIBaseModel):
    # All fields are required; `type` has no default, so the literal must be
    # passed explicitly when constructing the event.
    content_index: int
    """The index of the content part that was added."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that was added."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.added"]
    """The type of the event. Always `response.reasoning_part.added`."""


1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
# vLLM Streaming Events
# Note: we override the response type with the vLLM ResponsesResponse type
class ResponseCompletedEvent(OpenAIResponseCompletedEvent):
    # Re-annotates the inherited `response` field; the type-checker override
    # warning is silenced on purpose.
    response: ResponsesResponse  # type: ignore[override]


class ResponseCreatedEvent(OpenAIResponseCreatedEvent):
    # Re-annotates `response` with the vLLM ResponsesResponse type; the
    # type-checker override warning is silenced on purpose.
    response: ResponsesResponse  # type: ignore[override]


class ResponseInProgressEvent(OpenAIResponseInProgressEvent):
    # Re-annotates `response` with the vLLM ResponsesResponse type; the
    # type-checker override warning is silenced on purpose.
    response: ResponsesResponse  # type: ignore[override]


1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
# Union of every event type that may appear in a streamed Responses reply.
# The first three members are the vLLM subclasses defined in this module;
# the reasoning-part events are the local stand-ins defined in this module.
StreamingResponsesResponse: TypeAlias = (
    ResponseCreatedEvent
    | ResponseInProgressEvent
    | ResponseCompletedEvent
    | ResponseOutputItemAddedEvent
    | ResponseOutputItemDoneEvent
    | ResponseContentPartAddedEvent
    | ResponseContentPartDoneEvent
    | ResponseReasoningTextDeltaEvent
    | ResponseReasoningTextDoneEvent
    | ResponseReasoningPartAddedEvent
    | ResponseReasoningPartDoneEvent
    | ResponseCodeInterpreterCallInProgressEvent
    | ResponseCodeInterpreterCallCodeDeltaEvent
    | ResponseWebSearchCallInProgressEvent
    | ResponseWebSearchCallSearchingEvent
    | ResponseWebSearchCallCompletedEvent
    | ResponseCodeInterpreterCallCodeDoneEvent
    | ResponseCodeInterpreterCallInterpretingEvent
    | ResponseCodeInterpreterCallCompletedEvent
    | ResponseMcpCallArgumentsDeltaEvent
    | ResponseMcpCallArgumentsDoneEvent
    | ResponseMcpCallInProgressEvent
    | ResponseMcpCallCompletedEvent
)
1802

1803

1804
class TokenizeCompletionRequest(OpenAIBaseModel):
    # Tokenize request for a plain text prompt.
    model: str | None = None
    prompt: str

    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
1821
1822
1823


class TokenizeChatRequest(OpenAIBaseModel):
    # Tokenize request for a list of chat messages.
    model: str | None = None
    messages: list[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    tools: list[ChatCompletionToolsParam] | None = Field(
        default=None,
        description=("A list of tools the model may call."),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests that explicitly enable both mutually exclusive flags.

        Only keys present in the raw input are inspected, so the field
        default of `add_generation_prompt` is not taken into account here.
        """
        # A "before" validator receives the raw, unvalidated input, which is
        # not guaranteed to be a mapping. Leave anything else for pydantic's
        # own type validation instead of failing with AttributeError.
        if not isinstance(data, dict):
            return data
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

1896

1897
# A tokenize request is either the completion-style or the chat-style model.
TokenizeRequest: TypeAlias = TokenizeCompletionRequest | TokenizeChatRequest
1898
1899
1900
1901
1902


class TokenizeResponse(OpenAIBaseModel):
    # `token_strs` is optional and defaults to None; the other fields are
    # required.
    count: int
    max_model_len: int
    tokens: list[int]
    token_strs: list[str] | None = None
1905
1906
1907


class DetokenizeRequest(OpenAIBaseModel):
    # Detokenize request: a required list of integer tokens and an optional
    # model name.
    model: str | None = None
    tokens: list[int]
1910
1911
1912
1913


class DetokenizeResponse(OpenAIBaseModel):
    # Single required string field.
    prompt: str
1914
1915


1916
1917
class TokenizerInfoResponse(OpenAIBaseModel):
    """
    Response containing tokenizer configuration
    equivalent to tokenizer_config.json
    """

    # Extra keys are accepted in addition to the one declared field.
    model_config = ConfigDict(extra="allow")
    tokenizer_class: str


1926
class LoadLoRAAdapterRequest(BaseModel):
    # Derives from plain `BaseModel` (not `OpenAIBaseModel`); both fields are
    # required.
    lora_name: str
    lora_path: str


1931
class UnloadLoRAAdapterRequest(BaseModel):
    # Derives from plain `BaseModel` (not `OpenAIBaseModel`); `lora_int_id`
    # is optional.
    lora_name: str
    lora_int_id: int | None = Field(default=None)
1934
1935
1936


## Protocols for Audio
# Output formats accepted by the `response_format` field of the audio
# request models below.
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
1938
1939
1940
1941


class TranscriptionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    language: str | None = None
    """The language of the input audio.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    ## TODO (varun) : Support if set to 0, certain thresholds are met !!

    timestamp_granularities: list[Literal["word", "segment"]] = Field(
        alias="timestamp_granularities[]", default=[]
    )
    """The timestamp granularities to populate for this transcription.

    `response_format` must be set `verbose_json` to use timestamp granularities.
    Either or both of these options are supported: `word`, or `segment`. Note:
    There is no additional latency for segment timestamps, but generating word
    timestamps incurs additional latency.
    """

    stream: bool | None = False
    """When set, it will enable output to be streamed in a similar fashion
    as the Chat Completion endpoint.
    """
    # --8<-- [start:transcription-extra-params]
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )
    # --8<-- [end:transcription-extra-params]

    to_language: str | None = None
    """The language of the output audio we transcribe to.

    Please note that this is not currently used by supported models at this
    time, but it is a placeholder for future use, matching translation api.
    """

    # --8<-- [start:transcription-sampling-params]
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    top_p: float | None = None
    """Enables nucleus (top-p) sampling, where tokens are selected from the
    smallest possible set whose cumulative probability exceeds `p`.
    """

    top_k: int | None = None
    """Limits sampling to the `k` most probable tokens at each step."""

    min_p: float | None = None
    """Filters out tokens with a probability lower than `min_p`, ensuring a
    minimum likelihood threshold during sampling.
    """

    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    frequency_penalty: float | None = 0.0
    """The frequency penalty to use for sampling."""

    repetition_penalty: float | None = None
    """The repetition penalty to use for sampling."""

    presence_penalty: float | None = 0.0
    """The presence penalty to use for sampling."""
    # --8<-- [end:transcription-sampling-params]

    # Default sampling parameters for transcription requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Translate this request into `SamplingParams`.

        For each of temperature, top_p, top_k, min_p and repetition_penalty,
        a value set on the request wins; otherwise the value comes from
        `default_sampling_params`, and finally from
        `_DEFAULT_SAMPLING_PARAMS`. `max_tokens` is always
        `default_max_tokens`.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        # NOTE(review): `temperature` is declared as `float` with default 0.0,
        # so this fallback looks unreachable through normal validation;
        # confirm the intent.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            presence_penalty=self.presence_penalty,
            # Streaming requests receive deltas; otherwise only the final
            # output is produced.
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            extra_args=self.vllm_xargs,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_transcription_request(cls, data):
        """Check the raw input before field validation.

        Raises:
            HTTPException: (422) if `file` was supplied as a string.
            ValueError: if a stream option is set while `stream` is falsy.
        """
        # A "before" validator receives the raw, unvalidated input, which is
        # not guaranteed to be a mapping. Leave anything else for pydantic's
        # own type validation instead of failing with AttributeError.
        if not isinstance(data, dict):
            return data

        if isinstance(data.get("file"), str):
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail="Expected 'file' to be a file-like object, not 'str'.",
            )

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data
2122
2123
2124


# Transcription response objects
2125
2126
2127
2128
2129
class TranscriptionUsageAudio(OpenAIBaseModel):
    # `type` is a fixed literal tag; `seconds` is a required integer.
    type: Literal["duration"] = "duration"
    seconds: int


2130
2131
2132
class TranscriptionResponse(OpenAIBaseModel):
    # Non-verbose transcription payload: the text plus a required usage
    # record.
    text: str
    """The transcribed text."""
    usage: TranscriptionUsageAudio
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150


class TranscriptionWord(OpenAIBaseModel):
    # Element type of `TranscriptionResponseVerbose.words`; all fields are
    # required.
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranscriptionSegment(OpenAIBaseModel):
    # Element type of `TranscriptionResponseVerbose.segments`. The three
    # quality metrics are optional; every other field is required.
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float | None = None
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float | None = None
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float | None = None
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranscriptionResponseVerbose(OpenAIBaseModel):
    # Verbose transcription payload; `segments` and `words` are optional.
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: list[TranscriptionSegment] | None = None
    """Segments of the transcribed text and their corresponding details."""

    words: list[TranscriptionWord] | None = None
    """Extracted words and their corresponding timestamps."""
2204
2205


2206
2207
2208
2209
2210
# A transcription result is either the plain or the verbose response model.
TranscriptionResponseVariant: TypeAlias = (
    TranscriptionResponse | TranscriptionResponseVerbose
)


2211
2212
class TranslationResponseStreamChoice(OpenAIBaseModel):
    # One entry of `TranslationStreamResponse.choices`.
    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None
2215
2216
2217
2218
2219
2220
2221
2222


class TranslationStreamResponse(OpenAIBaseModel):
    # One streamed translation chunk. `id` ("trsl-" prefix) and `created`
    # are generated at construction time when not supplied.
    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
    object: Literal["translation.chunk"] = "translation.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranslationResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235


class TranslationRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranslation

    file: UploadFile
    """
    The audio file object (not file name) to translate, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    # TODO support additional sampling parameters
    # --8<-- [start:translation-sampling-params]
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """
    # --8<-- [end:translation-sampling-params]

    # --8<-- [start:translation-extra-params]
    language: str | None = None
    """The language of the input audio we translate from.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy.
    """

    to_language: str | None = None
    """The language we translate the input audio to.

    Please note that this is not supported by all models, refer to the specific
    model documentation for more details.
    For instance, Whisper only supports `to_language=en`.
    """

    stream: bool | None = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
    Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False
    # --8<-- [end:translation-extra-params]

    # Default sampling parameters for translation requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Translate this request into `SamplingParams`.

        A temperature set on the request wins; otherwise it comes from
        `default_sampling_params`, and finally from
        `_DEFAULT_SAMPLING_PARAMS`. `max_tokens` is always
        `default_max_tokens`.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        # Default parameters
        # NOTE(review): `temperature` is declared as `float` with default 0.0,
        # so this fallback looks unreachable through normal validation;
        # confirm the intent.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            # Streaming requests receive deltas; otherwise only the final
            # output is produced.
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject stream options that are set while `stream` is falsy.

        Raises:
            ValueError: if `stream_include_usage` or
                `stream_continuous_usage_stats` is truthy without `stream`.
        """
        # A "before" validator receives the raw, unvalidated input, which is
        # not guaranteed to be a mapping. Leave anything else for pydantic's
        # own type validation instead of failing with AttributeError.
        if not isinstance(data, dict):
            return data

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError("Stream options can only be defined when `stream=True`.")

        return data


# Translation response objects
class TranslationResponse(OpenAIBaseModel):
    """Translation result that carries only the translated text."""

    text: str
    """Text of the translation."""


class TranslationWord(OpenAIBaseModel):
    """A single word together with its start and end timestamps."""

    end: float
    """Time, in seconds, at which the word ends."""

    start: float
    """Time, in seconds, at which the word starts."""

    word: str
    """Text of the word."""


class TranslationSegment(OpenAIBaseModel):
    """One segment of a translation: its text, tokens, timing and scores."""

    id: int
    """Unique identifier of the segment."""

    avg_logprob: float | None = None
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float | None = None
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float | None = None
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranslationResponseVerbose(OpenAIBaseModel):
    """Translation result with audio metadata and optional segment/word detail."""

    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The translated text."""

    segments: list[TranslationSegment] | None = None
    """Segments of the translated text and their corresponding details."""

    words: list[TranslationWord] | None = None
    """Extracted words and their corresponding timestamps."""


# Either of the two translation response models: plain or verbose.
TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose


####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
    """Request model for token-in / token-out generation."""

    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    token_ids: list[int]
    """The token ids to generate text from."""

    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None
    """The processed MM inputs for the model."""

    sampling_params: SamplingParams
    """The sampling parameters for the model."""

    model: str | None = None

    stream: bool | None = False
    stream_options: StreamOptions | None = None
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )


class GenerateResponseChoice(BaseModel):
    """One generated choice, as listed in `GenerateResponse.choices`."""

    index: int
    logprobs: ChatCompletionLogProbs | None = None
    # "stop" is the default finish reason per the OpenAI spec.
    finish_reason: str | None = "stop"
    token_ids: list[int] | None = None


class GenerateResponse(BaseModel):
    """Response model for token-in / token-out generation."""

    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    choices: list[GenerateResponseChoice]

    prompt_logprobs: list[dict[int, Logprob] | None] | None = None

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )