protocol.py 89.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
6
import json
Zhuohan Li's avatar
Zhuohan Li committed
7
import time
8
from http import HTTPStatus
9
from typing import Annotated, Any, ClassVar, Literal, TypeAlias
Zhuohan Li's avatar
Zhuohan Li committed
10

11
import regex as re
12
import torch
13
from fastapi import HTTPException, UploadFile
14
from openai.types.chat.chat_completion_audio import (
15
16
17
    ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
18
19
20
21
22
from openai.types.responses import (
    ResponseCodeInterpreterCallCodeDeltaEvent,
    ResponseCodeInterpreterCallCodeDoneEvent,
    ResponseCodeInterpreterCallCompletedEvent,
    ResponseCodeInterpreterCallInProgressEvent,
23
24
25
26
27
    ResponseCodeInterpreterCallInterpretingEvent,
    ResponseContentPartAddedEvent,
    ResponseContentPartDoneEvent,
    ResponseFunctionToolCall,
    ResponseInputItemParam,
28
29
30
31
    ResponseMcpCallArgumentsDeltaEvent,
    ResponseMcpCallArgumentsDoneEvent,
    ResponseMcpCallCompletedEvent,
    ResponseMcpCallInProgressEvent,
32
33
34
35
36
37
38
39
40
41
42
    ResponseOutputItem,
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponsePrompt,
    ResponseReasoningTextDeltaEvent,
    ResponseReasoningTextDoneEvent,
    ResponseStatus,
    ResponseWebSearchCallCompletedEvent,
    ResponseWebSearchCallInProgressEvent,
    ResponseWebSearchCallSearchingEvent,
)
43
from openai.types.responses import (
44
45
46
    ResponseCompletedEvent as OpenAIResponseCompletedEvent,
)
from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent
47
from openai.types.responses import (
48
49
    ResponseInProgressEvent as OpenAIResponseInProgressEvent,
)
50
from openai.types.responses.response_reasoning_item import (
51
52
    Content as ResponseReasoningTextContent,
)
53
from openai_harmony import Message as OpenAIHarmonyMessage
54
55
56
57
58

# Backward compatibility for OpenAI client versions
try:  # For older openai versions (< 1.100.0)
    from openai.types.responses import ResponseTextConfig
except ImportError:  # For newer openai versions (>= 1.100.0)
59
    from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig
60

61

62
from openai.types.responses.response import IncompleteDetails, ToolChoice
63
64
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
65
66
67
68
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
69
    ValidationError,
70
    field_serializer,
71
72
    model_validator,
)
Zhuohan Li's avatar
Zhuohan Li committed
73

74
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
75
from vllm.exceptions import VLLMValidationError
76
from vllm.logger import init_logger
77
from vllm.logprobs import Logprob
78
79
80
81
82
83
from vllm.sampling_params import (
    BeamSearchParams,
    RequestOutputKind,
    SamplingParams,
    StructuredOutputsParams,
)
84
85
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
86

87
88
logger = init_logger(__name__)

89
_LONG_INFO = torch.iinfo(torch.long)
90

Zhuohan Li's avatar
Zhuohan Li committed
91

92
class OpenAIBaseModel(BaseModel):
    """Base class for the OpenAI-compatible protocol models in this module.

    Extra (undeclared) fields are accepted, and a warning is logged whenever a
    model is validated from a dict containing keys that are neither field
    names nor field aliases of the concrete class.
    """

    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Per-class cache of field names and aliases (populated lazily)
    field_names: ClassVar[set[str] | None] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Run normal validation, then warn about unrecognised input keys.

        Args:
            data: The raw input being validated.
            handler: The inner validator supplied by pydantic.

        Returns:
            Whatever ``handler(data)`` returns.
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Read the cache from this class's own namespace only. A plain
        # `cls.field_names` lookup would also find a set cached on a parent
        # class, and the subclass's additional fields would then be reported
        # as ignored.
        field_names = cls.__dict__.get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        extra_fields = data.keys() - field_names
        if extra_fields:
            logger.warning(
                "The following fields were present in the request but ignored: %s",
                extra_fields,
            )
        return result
122
123


124
class ErrorInfo(OpenAIBaseModel):
    """Details of one error: a message, a type, an optional param and a code."""

    message: str
    type: str
    param: str | None = None
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
129
130


131
132
133
134
class ErrorResponse(OpenAIBaseModel):
    """Response envelope that carries a single `ErrorInfo` under `error`."""

    error: ErrorInfo


135
class ModelPermission(OpenAIBaseModel):
    """Permission record listed on a `ModelCard`.

    `id` defaults to ``modelperm-<random uuid>`` and `created` defaults to the
    current time as integer seconds (``int(time.time())``).
    """

    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: str | None = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
148
149


150
class ModelCard(OpenAIBaseModel):
    """Description of one served model.

    `created` defaults to the current time as integer seconds and
    `permission` defaults to an empty list.
    """

    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: str | None = None
    parent: str | None = None
    max_model_len: int | None = None
    permission: list[ModelPermission] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
159
160


161
class ModelList(OpenAIBaseModel):
    """List container of `ModelCard` entries; `data` defaults to empty."""

    object: str = "list"
    data: list[ModelCard] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
164
165


166
class PromptTokenUsageInfo(OpenAIBaseModel):
    """Prompt-token usage details; currently only an optional cached count."""

    cached_tokens: int | None = None
168
169


170
class UsageInfo(OpenAIBaseModel):
    """Token usage counters reported with a response.

    All counters default to 0; `prompt_tokens_details` is optional.
    """

    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: int | None = 0
    prompt_tokens_details: PromptTokenUsageInfo | None = None
Zhuohan Li's avatar
Zhuohan Li committed
175
176


177
178
class RequestResponseMetadata(BaseModel):
    """Per-request metadata: the request id and, optionally, final usage.

    Derives from plain `BaseModel`, so it does not get the extra-field
    handling of `OpenAIBaseModel`.
    """

    request_id: str
    final_usage_info: UsageInfo | None = None
180
181


182
183
class JsonSchemaResponseFormat(OpenAIBaseModel):
    """Named JSON schema used by a `json_schema` response format."""

    name: str
    description: str | None = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    strict: bool | None = None
189
190


191
class LegacyStructuralTag(OpenAIBaseModel):
    """One legacy structural tag: `begin` and `end` strings plus a schema."""

    begin: str
    # schema is the field, but that causes conflicts with pydantic so
    # instead use structural_tag_schema with an alias
    structural_tag_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    end: str


199
class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
    """Legacy `structural_tag` response format: structures and triggers."""

    type: Literal["structural_tag"]
    structures: list[LegacyStructuralTag]
    triggers: list[str]


205
206
207
208
209
210
211
212
213
214
class StructuralTagResponseFormat(OpenAIBaseModel):
    """`structural_tag` response format carrying an untyped `format` payload."""

    type: Literal["structural_tag"]
    # Typed as Any, so this model does not validate the payload's structure.
    format: Any


# Either variant of the structural-tag response format (legacy or current).
AnyStructuralTagResponseFormat: TypeAlias = (
    LegacyStructuralTagResponseFormat | StructuralTagResponseFormat
)


215
class ResponseFormat(OpenAIBaseModel):
    """Response format selector, with an optional schema for `json_schema`."""

    # type must be "json_schema", "json_object", or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: JsonSchemaResponseFormat | None = None
219
220


221
222
223
# Any response format accepted by `ChatCompletionRequest.response_format`:
# plain/JSON formats or either structural-tag variant.
AnyResponseFormat: TypeAlias = (
    ResponseFormat | StructuralTagResponseFormat | LegacyStructuralTagResponseFormat
)
224
225


226
class StreamOptions(OpenAIBaseModel):
    """Streaming options; both are optional flags.

    `include_usage` defaults to True and `continuous_usage_stats` to False.
    """

    include_usage: bool | None = True
    continuous_usage_stats: bool | None = False
229
230


231
232
class FunctionDefinition(OpenAIBaseModel):
    """Function declared by a tool: name, optional description and parameters."""

    name: str
    description: str | None = None
    parameters: dict[str, Any] | None = None
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250


class ChatCompletionToolsParam(OpenAIBaseModel):
    """A tool entry in a chat completion request; only `function` is allowed."""

    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    """Reference to a function by name."""

    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    """Tool-choice value that names one specific function."""

    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


251
252
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
253
254
class LogitsProcessorConstructor(BaseModel):
    """Constructor spec for a logits processor.

    `qualname` names the object to resolve; `args` and `kwargs` are the
    optional positional and keyword arguments it is called with.
    """

    qualname: str
    args: list[Any] | None = None
    kwargs: dict[str, Any] | None = None

    # Unknown keys are rejected rather than stored as extras.
    model_config = ConfigDict(extra="forbid")

260

261
# Each entry is either a qualified name or a constructor spec.
LogitsProcessors = list[str | LogitsProcessorConstructor]
262
263


264
def get_logits_processors(
    processors: LogitsProcessors | None, pattern: str | None
) -> list[Any] | None:
    """Resolve the requested logits processors, enforcing the allow-pattern.

    Args:
        processors: Qualified names, or constructor specs carrying a
            qualified name plus optional args/kwargs.
        pattern: Regular expression every qualified name must match. The
            check uses ``re.match``, so it is anchored at the start of the
            name only.

    Returns:
        The resolved processors in request order, or ``None`` when no
        processors were requested.

    Raises:
        ValueError: If processors were requested but no pattern is
            configured, a name does not match the pattern, or a name cannot
            be resolved.
    """
    if not processors:
        return None
    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information."
        )

    logits_processors = []
    for processor in processors:
        qualname = processor if isinstance(processor, str) else processor.qualname
        if not re.match(pattern, qualname):
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information."
            )
        try:
            logits_processor = resolve_obj_by_qualname(qualname)
        except Exception as e:
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {e}"
            ) from e
        # A constructor spec means the resolved object is a class/factory
        # that still has to be called with the supplied arguments.
        if isinstance(processor, LogitsProcessorConstructor):
            logits_processor = logits_processor(
                *processor.args or [], **processor.kwargs or {}
            )
        logits_processors.append(logits_processor)
    return logits_processors


298
# An element of `ResponsesRequest.input`: an input item param or an output item.
ResponseInputOutputItem: TypeAlias = ResponseInputItemParam | ResponseOutputItem
299
300


301
302
303
class ResponsesRequest(OpenAIBaseModel):
    """Request body for creating a response.

    Declares the documented OpenAI parameters first, then server-specific
    extras, and provides `to_sampling_params` to turn the request into
    `SamplingParams`. Several ``mode="before"`` validators check the raw
    input; each one passes non-dict input through untouched so that pydantic
    itself reports the type problem.
    """

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/responses/create
    background: bool | None = False
    include: (
        list[
            Literal[
                "code_interpreter_call.outputs",
                "computer_call_output.output.image_url",
                "file_search_call.results",
                "message.input_image.image_url",
                "message.output_text.logprobs",
                "reasoning.encrypted_content",
            ],
        ]
        | None
    ) = None
    input: str | list[ResponseInputOutputItem]
    instructions: str | None = None
    max_output_tokens: int | None = None
    max_tool_calls: int | None = None
    metadata: Metadata | None = None
    model: str | None = None
    logit_bias: dict[str, float] | None = None
    parallel_tool_calls: bool | None = True
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
    store: bool | None = True
    stream: bool | None = False
    temperature: float | None = None
    text: ResponseTextConfig | None = None
    tool_choice: ToolChoice = "auto"
    tools: list[Tool] = Field(default_factory=list)
    top_logprobs: int | None = 0
    top_p: float | None = None
    top_k: int | None = None
    truncation: Literal["auto", "disabled"] | None = "disabled"
    user: str | None = None

    # --8<-- [start:responses-extra-params]
    request_id: str = Field(
        default_factory=lambda: f"resp_{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    enable_response_messages: bool = Field(
        default=False,
        description=(
            "Dictates whether or not to return messages as part of the "
            "response object. Currently only supported for "
            "non-background and gpt-oss only. "
        ),
    )
    # similar to input_messages / output_messages in ResponsesResponse
    # we take in previous_input_messages (ie in harmony format)
    # this cannot be used in conjunction with previous_response_id
    # TODO: consider supporting non harmony messages as well
    previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
    # --8<-- [end:responses-extra-params]

    # Last-resort values used when neither the request nor the caller-provided
    # defaults specify a sampling parameter.
    _DEFAULT_SAMPLING_PARAMS = {
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
    }

    def to_sampling_params(
        self,
        default_max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Build `SamplingParams` for this request.

        Args:
            default_max_tokens: Upper bound on generated tokens; also used
                when the request does not set `max_output_tokens`.
            default_sampling_params: Optional fallback values for
                `temperature`, `top_p`, `top_k` and `stop_token_ids`.

        Returns:
            The sampling parameters derived from the request.

        Raises:
            NotImplementedError: If `text.format` is of type `json_object`.
        """
        if self.max_output_tokens is None:
            max_tokens = default_max_tokens
        else:
            # The request may lower the limit but never raise it.
            max_tokens = min(self.max_output_tokens, default_max_tokens)

        default_sampling_params = default_sampling_params or {}
        # Resolution order: request value, caller defaults, class defaults.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        stop_token_ids = default_sampling_params.get("stop_token_ids")

        # Structured output
        structured_outputs = None
        if self.text is not None and self.text.format is not None:
            response_format = self.text.format
            if (
                response_format.type == "json_schema"
                and response_format.schema_ is not None
            ):
                structured_outputs = StructuredOutputsParams(
                    json=response_format.schema_
                )
            elif response_format.type == "json_object":
                raise NotImplementedError("json_object is not supported")

        # TODO: add more parameters
        return SamplingParams.from_optional(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_tokens=max_tokens,
            logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
            stop_token_ids=stop_token_ids,
            output_kind=(
                RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
            ),
            structured_outputs=structured_outputs,
            logit_bias=self.logit_bias,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )

    def is_include_output_logprobs(self) -> bool:
        """Check if the request includes output logprobs."""
        if self.include is None:
            return False
        return (
            isinstance(self.include, list)
            and "message.output_text.logprobs" in self.include
        )

    @model_validator(mode="before")
    @classmethod
    def validate_background(cls, data):
        """Reject `background` requests that explicitly disable `store`."""
        # Non-dict input has no `.get`; leave it for pydantic to validate.
        if not isinstance(data, dict) or not data.get("background"):
            return data
        if not data.get("store", True):
            raise ValueError("background can only be used when `store` is true")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt(cls, data):
        """Reject requests that set `prompt` (prompt templates)."""
        if isinstance(data, dict) and data.get("prompt") is not None:
            raise VLLMValidationError(
                "prompt template is not supported", parameter="prompt"
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Require `cache_salt`, when provided, to be a non-empty string."""
        if not isinstance(data, dict):
            return data
        if data.get("cache_salt") is not None and (
            not isinstance(data["cache_salt"], str) or not data["cache_salt"]
        ):
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def function_call_parsing(cls, data):
        """Parse function_call dictionaries into ResponseFunctionToolCall objects.
        This ensures Pydantic can properly resolve union types in the input field.
        Function calls provided as dicts are converted to ResponseFunctionToolCall
        objects before validation, while invalid structures are left for Pydantic
        to reject with appropriate error messages.
        """
        if not isinstance(data, dict):
            return data

        input_data = data.get("input")

        # Early return for None, strings, or bytes
        # (strings are iterable but shouldn't be processed)
        if input_data is None or isinstance(input_data, (str, bytes)):
            return data

        # Convert iterators (like ValidatorIterator) to list
        if not isinstance(input_data, list):
            try:
                input_data = list(input_data)
            except TypeError:
                # Not iterable, leave as-is for Pydantic to handle
                return data

        processed_input = []
        for item in input_data:
            if isinstance(item, dict) and item.get("type") == "function_call":
                try:
                    processed_input.append(ResponseFunctionToolCall(**item))
                except ValidationError:
                    # Let Pydantic handle validation for malformed function calls
                    logger.debug(
                        "Failed to parse function_call to ResponseFunctionToolCall, "
                        "leaving for Pydantic validation"
                    )
                    processed_input.append(item)
            else:
                processed_input.append(item)

        data["input"] = processed_input
        return data

528

529
class ChatCompletionRequest(OpenAIBaseModel):
530
531
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
532
    messages: list[ChatCompletionMessageParam]
533
534
535
536
537
538
    model: str | None = None
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: bool | None = False
    top_logprobs: int | None = 0
    max_tokens: int | None = Field(
539
        default=None,
540
541
        deprecated="max_tokens is deprecated in favor of "
        "the max_completion_tokens field",
542
    )
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
    max_completion_tokens: int | None = None
    n: int | None = 1
    presence_penalty: float | None = 0.0
    response_format: AnyResponseFormat | None = None
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    temperature: float | None = None
    top_p: float | None = None
    tools: list[ChatCompletionToolsParam] | None = None
    tool_choice: (
        Literal["none"]
        | Literal["auto"]
        | Literal["required"]
        | ChatCompletionNamedToolChoiceParam
        | None
    ) = "none"
    reasoning_effort: Literal["low", "medium", "high"] | None = None
562
    include_reasoning: bool = True
563
    parallel_tool_calls: bool | None = True
564

565
    # NOTE this will be ignored by vLLM
566
    user: str | None = None
567

568
    # --8<-- [start:chat-completion-sampling-params]
569
    use_beam_search: bool = False
570
571
572
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
573
    length_penalty: float = 1.0
574
    stop_token_ids: list[int] | None = []
575
576
577
578
579
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
580
581
582
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    prompt_logprobs: int | None = None
    allowed_token_ids: list[int] | None = None
583
    bad_words: list[str] = Field(default_factory=list)
584
    # --8<-- [end:chat-completion-sampling-params]
585

586
    # --8<-- [start:chat-completion-extra-params]
587
    echo: bool = Field(
588
589
590
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
591
592
            "if they belong to the same role."
        ),
593
    )
594
    add_generation_prompt: bool = Field(
595
        default=True,
596
597
598
599
600
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
601
    )
602
603
    continue_final_message: bool = Field(
        default=False,
604
605
606
607
608
609
610
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
611
    )
612
    add_special_tokens: bool = Field(
613
614
615
616
617
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
618
            "special tokens so this should be set to false (as is the "
619
620
            "default)."
        ),
621
    )
622
    documents: list[dict[str, str]] | None = Field(
623
        default=None,
624
625
626
627
628
629
630
        description=(
            "A list of dicts representing documents that will be accessible to "
            "the model if it is performing RAG (retrieval-augmented generation)."
            " If the template does not support RAG, this argument will have no "
            "effect. We recommend that each document should be a dict containing "
            '"title" and "text" keys.'
        ),
631
    )
632
    chat_template: str | None = Field(
633
634
635
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
636
637
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
638
639
            "does not define one."
        ),
640
    )
641
    chat_template_kwargs: dict[str, Any] | None = Field(
642
        default=None,
643
644
        description=(
            "Additional keyword args to pass to the template renderer. "
645
646
            "Will be accessible by the chat template."
        ),
647
    )
648
    mm_processor_kwargs: dict[str, Any] | None = Field(
649
650
651
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
652
    structured_outputs: StructuredOutputsParams | None = Field(
653
        default=None,
654
        description="Additional kwargs for structured outputs",
655
    )
656
657
658
659
660
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
661
662
            "if the served model does not use priority scheduling."
        ),
663
    )
664
    request_id: str = Field(
665
        default_factory=random_uuid,
666
667
668
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
669
670
            "through out the inference process and return in response."
        ),
671
    )
672
    logits_processors: LogitsProcessors | None = Field(
673
674
675
676
677
678
679
680
681
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
682
683
684
            "{'param': 'value'}}."
        ),
    )
685
    return_tokens_as_token_ids: bool | None = Field(
686
687
688
689
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
690
691
692
            "that are not JSON-encodable can be identified."
        ),
    )
693
    return_token_ids: bool | None = Field(
694
695
696
697
698
699
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
700
701
702
            "need to map generated text back to input tokens."
        ),
    )
703
    cache_salt: str | None = Field(
704
705
706
707
708
709
710
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
711
            "to 256 bit)."
712
713
        ),
    )
714
    kv_transfer_params: dict[str, Any] | None = Field(
Robert Shaw's avatar
Robert Shaw committed
715
        default=None,
716
717
        description="KVTransfer parameters used for disaggregated serving.",
    )
718

719
    vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
720
        default=None,
721
        description=(
722
            "Additional request parameters with (list of) string or "
723
724
            "numeric values, used by custom extensions."
        ),
725
726
    )

727
    # --8<-- [end:chat-completion-extra-params]
Zhuohan Li's avatar
Zhuohan Li committed
728

729
730
731
732
733
    # Default sampling parameters for chat completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
734
        "top_k": 0,
735
736
737
738
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self, max_tokens: int, default_sampling_params: dict
    ) -> BeamSearchParams:
        """Build `BeamSearchParams` for this request.

        Args:
            max_tokens: Maximum number of tokens to generate.
            default_sampling_params: Fallback values; only `temperature` is
                read here.

        Returns:
            Beam search parameters whose beam width is the request's `n`
            (1 when `n` is None).
        """
        n = self.n if self.n is not None else 1
        # Resolution order: request value, caller defaults, class defaults.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )
755

756
    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict,
    ) -> SamplingParams:
        """Translate this request into `SamplingParams`.

        Args:
            max_tokens: Value forwarded as `max_tokens`.
            logits_processor_pattern: Passed to `get_logits_processors`
                together with the request's `logits_processors`.
            default_sampling_params: Fallback values consulted for any of
                `repetition_penalty`, `temperature`, `top_p`, `top_k` and
                `min_p` that the request leaves unset.

        Returns:
            The object produced by `SamplingParams.from_optional`.

        Note:
            When `response_format` is set, `self.structured_outputs` is
            created (if missing) and updated in place. The request's
            `vllm_xargs` dict is never modified.
        """
        # Default parameters: request value, then the caller's defaults,
        # then the class-level fallback.
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        # With `echo`, an unset `prompt_logprobs` inherits `top_logprobs`.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        response_format = self.response_format
        if response_format is not None:
            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            # Set structured output params for response format
            if response_format.type == "json_object":
                self.structured_outputs.json_object = True
            elif response_format.type == "json_schema":
                json_schema = response_format.json_schema
                assert json_schema is not None
                self.structured_outputs.json = json_schema.json_schema
            elif response_format.type == "structural_tag":
                structural_tag = response_format
                assert structural_tag is not None and isinstance(
                    structural_tag,
                    (
                        LegacyStructuralTagResponseFormat,
                        StructuralTagResponseFormat,
                    ),
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

        # Copy so that adding `kv_transfer_params` below does not mutate the
        # request's own `vllm_xargs` dict.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            # `top_logprobs` only takes effect when `logprobs` is enabled.
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            bad_words=self.bad_words,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )
853

854
    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Require `stream=True` whenever `stream_options` is supplied."""
        if not data.get("stream_options") or data.get("stream"):
            # Either no stream options were given, or streaming is enabled.
            return data

        raise VLLMValidationError(
            "Stream options can only be defined when `stream=True`.",
            parameter="stream_options",
        )

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `top_logprobs` on the raw request.

        Raises:
            VLLMValidationError: If `prompt_logprobs` is requested (a value
                above 0, or -1) while streaming; if either value is negative
                and not -1; or if `top_logprobs` is requested while
                `logprobs` is not set.
        """
        prompt_logprobs = data.get("prompt_logprobs")
        has_prompt_logprobs = prompt_logprobs is not None

        if (
            has_prompt_logprobs
            and data.get("stream")
            and (prompt_logprobs > 0 or prompt_logprobs == -1)
        ):
            raise VLLMValidationError(
                "`prompt_logprobs` are not available when `stream=True`.",
                parameter="prompt_logprobs",
            )

        if has_prompt_logprobs and prompt_logprobs < 0 and prompt_logprobs != -1:
            raise VLLMValidationError(
                "`prompt_logprobs` must be a positive value or -1.",
                parameter="prompt_logprobs",
                value=prompt_logprobs,
            )

        top_logprobs = data.get("top_logprobs")
        if top_logprobs is None:
            return data

        if top_logprobs < 0 and top_logprobs != -1:
            raise VLLMValidationError(
                "`top_logprobs` must be a positive value or -1.",
                parameter="top_logprobs",
                value=top_logprobs,
            )

        if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"):
            raise VLLMValidationError(
                "when using `top_logprobs`, `logprobs` must be set to true.",
                parameter="top_logprobs",
            )

        return data
896

897
898
    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Reject conflicting structured-output constraints.

        Looks at the raw `structured_outputs` mapping and counts how many of
        `json`, `regex` and `choice` are set.

        Returns:
            `data` unchanged when the combination is acceptable.

        Raises:
            ValueError: If more than one constraint kind is set, or if a
                constraint is combined with a named `tool_choice` (anything
                other than null, "none", "auto" or "required"). A
                `ValueError` passed in as `data` is re-raised as is.
        """
        if isinstance(data, ValueError):
            raise data

        if data.get("structured_outputs", None) is None:
            return data

        structured_outputs_kwargs = data["structured_outputs"]
        count = sum(
            structured_outputs_kwargs.get(k) is not None
            for k in ("json", "regex", "choice")
        )
        # you can only use one kind of constraints for structured outputs
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )
        # you can only either use structured outputs or tools, not both.
        # `count` is 0 or 1 at this point, so the test must be `== 1`; a
        # null `tool_choice` is treated the same as "none".
        if count == 1 and data.get("tool_choice") not in (
            None,
            "none",
            "auto",
            "required",
        ):
            raise ValueError(
                "You can only either use constraints for structured outputs "
                "or tools, not both."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        """Validate the `tool_choice` / `tools` combination on the raw request.

        Behavior:
            * `tools` given without `tool_choice`: `tool_choice` is set to
              "auto".
            * `tool_choice` absent, null or "none": nothing is validated.
            * `tool_choice` "required" with an empty `tools` list:
              `tool_choice` is rewritten to "none" and `tools` is removed
              from `data`.
            * `tool_choice` given as an object: its `function.name` must
              match the `function.name` of one entry in `tools`.

        Raises:
            ValueError: If `tool_choice` is set without `tools`, has an
                unsupported value, is a malformed object, or names a
                function that is not among `tools`.
        """
        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        # if "tool_choice" is "none" -- no validation is needed for tools
        if "tool_choice" in data and data["tool_choice"] == "none":
            return data

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data and data["tool_choice"] is not None:
            # ensure that if "tool choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError("When using `tool_choice`, `tools` must be set.")

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto" or "required"
            if data["tool_choice"] not in ["auto", "required"] and not isinstance(
                data["tool_choice"], dict
            ):
                raise ValueError(
                    f"Invalid value for `tool_choice`: {data['tool_choice']}! "
                    'Only named tools, "none", "auto" or "required" '
                    "are supported."
                )

            # if tool_choice is "required" but the "tools" list is empty,
            # override the data to behave like "none" to align with
            # OpenAI’s behavior.
            if (
                data["tool_choice"] == "required"
                and isinstance(data["tools"], list)
                and len(data["tools"]) == 0
            ):
                data["tool_choice"] = "none"
                del data["tools"]
                return data

            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            correct_usage_message = (
                'Correct usage: `{"type": "function",'
                ' "function": {"name": "my_function"}}`'
            )
            if isinstance(data["tool_choice"], dict):
                function = data["tool_choice"].get("function")
                if not isinstance(function, dict):
                    raise ValueError(
                        f"Invalid value for `function`: `{function}` in "
                        f"`tool_choice`! {correct_usage_message}"
                    )
                if "name" not in function:
                    raise ValueError(
                        f"Expected field `name` in `function` in "
                        f"`tool_choice`! {correct_usage_message}"
                    )
                function_name = function["name"]
                if not isinstance(function_name, str) or len(function_name) == 0:
                    raise ValueError(
                        f"Invalid `name` in `function`: `{function_name}`"
                        f" in `tool_choice`! {correct_usage_message}"
                    )
                # Entries of `tools` are still raw, unvalidated input here,
                # so look the name up defensively: a malformed entry simply
                # cannot match, instead of raising KeyError/TypeError.
                valid_tool = any(
                    isinstance(tool, dict)
                    and isinstance(tool.get("function"), dict)
                    and tool["function"].get("name") == function_name
                    for tool in data["tools"]
                )
                if not valid_tool:
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match any"
                        " of the specified `tools`"
                    )
        return data

1006
1007
1008
    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Forbid enabling both prompt-continuation flags at once.

        Raises:
            ValueError: If `continue_final_message` and
                `add_generation_prompt` are both truthy.
        """
        if not data.get("continue_final_message"):
            return data
        if not data.get("add_generation_prompt"):
            return data

        raise ValueError(
            "Cannot set both `continue_final_message` and "
            "`add_generation_prompt` to True."
        )

1016
1017
1018
    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Ensure `cache_salt`, when provided, is a non-empty string.

        Raises:
            ValueError: If `cache_salt` is present (not None) but is not a
                string or is the empty string.
        """
        salt = data.get("cache_salt")
        if salt is None:
            return data
        if isinstance(salt, str) and salt:
            return data

        raise ValueError(
            "Parameter 'cache_salt' must be a non-empty string if provided."
        )

Zhuohan Li's avatar
Zhuohan Li committed
1027

1028
class CompletionRequest(OpenAIBaseModel):
    # Request model for completions. Besides declaring the request fields,
    # it validates the raw input (the `mode="before"` validators at the
    # bottom) and converts itself into `SamplingParams` /
    # `BeamSearchParams` via the `to_*` methods.
    #
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str | None = None
    prompt: list[int] | list[list[int]] | str | list[str] | None = None
    echo: bool | None = False
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: int | None = None
    max_tokens: int | None = 16
    n: int = 1
    presence_penalty: float | None = 0.0
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    suffix: str | None = None
    temperature: float | None = None
    top_p: float | None = None
    user: str | None = None

    # --8<-- [start:completion-sampling-params]
    use_beam_search: bool = False
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
    length_penalty: float = 1.0
    stop_token_ids: list[int] | None = []
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    allowed_token_ids: list[int] | None = None
    prompt_logprobs: int | None = None
    # --8<-- [end:completion-sampling-params]

    # --8<-- [start:completion-extra-params]
    prompt_embeds: bytes | list[bytes] | None = None
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    response_format: AnyResponseFormat | None = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
            ", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
        ),
    )
    structured_outputs: StructuredOutputsParams | None = Field(
        default=None,
        description="Additional kwargs for structured outputs",
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    logits_processors: LogitsProcessors | None = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."
        ),
    )

    return_tokens_as_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."
        ),
    )
    return_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."
        ),
    )

    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )

    # --8<-- [end:completion-extra-params]

    # Default sampling parameters for completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self,
        max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> BeamSearchParams:
        """Build the `BeamSearchParams` for this request.

        Args:
            max_tokens: Value forwarded as `max_tokens`.
            default_sampling_params: Optional fallback values; only
                "temperature" is consulted here.

        Returns:
            Beam-search parameters with `n` as the beam width.
        """
        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        if (temperature := self.temperature) is None:
            # Use the shared class-level fallback rather than a literal so
            # this stays in sync with `to_sampling_params`.
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Translate this request into `SamplingParams`.

        Args:
            max_tokens: Value forwarded as `max_tokens`, except for the
                echo-only case noted in the body.
            logits_processor_pattern: Passed to `get_logits_processors`
                together with the request's `logits_processors`.
            default_sampling_params: Optional fallback values consulted for
                any of `repetition_penalty`, `temperature`, `top_p`, `top_k`
                and `min_p` that the request leaves unset.

        Returns:
            The object produced by `SamplingParams.from_optional`.

        Note:
            When `response_format` is set, `self.structured_outputs` is
            created (if missing) and updated in place. The request's
            `vllm_xargs` dict is never modified.
        """
        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters: request value, then the caller's defaults,
        # then the class-level fallback.
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        # With `echo`, an unset `prompt_logprobs` inherits `logprobs`.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        # Echo with `max_tokens == 0`: `max_tokens=1` is passed on below
        # instead of the caller's value.
        echo_without_generation = self.echo and self.max_tokens == 0

        response_format = self.response_format
        if response_format is not None:
            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            # Set structured output params for response format
            if response_format.type == "json_object":
                self.structured_outputs.json_object = True
            elif response_format.type == "json_schema":
                json_schema = response_format.json_schema
                assert json_schema is not None
                self.structured_outputs.json = json_schema.json_schema
            elif response_format.type == "structural_tag":
                structural_tag = response_format
                assert structural_tag is not None and isinstance(
                    structural_tag,
                    (
                        LegacyStructuralTagResponseFormat,
                        StructuralTagResponseFormat,
                    ),
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

        # Copy so that adding `kv_transfer_params` below does not mutate the
        # request's own `vllm_xargs` dict.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )

    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Allow at most one of the `json`, `regex`, `choice` constraints.

        Raises:
            VLLMValidationError: If more than one of them is set in the raw
                `structured_outputs` mapping.
        """
        if data.get("structured_outputs", None) is None:
            return data

        structured_outputs_kwargs = data["structured_outputs"]
        count = sum(
            structured_outputs_kwargs.get(k) is not None
            for k in ("json", "regex", "choice")
        )
        if count > 1:
            raise VLLMValidationError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice').",
                parameter="structured_outputs",
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `logprobs` on the raw request.

        Raises:
            VLLMValidationError: If `prompt_logprobs` is requested (a value
                above 0, or -1) while streaming; if `prompt_logprobs` is
                negative and not -1; or if `logprobs` is negative.
        """
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
                raise VLLMValidationError(
                    "`prompt_logprobs` are not available when `stream=True`.",
                    parameter="prompt_logprobs",
                )

            if prompt_logprobs < 0 and prompt_logprobs != -1:
                raise VLLMValidationError(
                    "`prompt_logprobs` must be a positive value or -1.",
                    parameter="prompt_logprobs",
                    value=prompt_logprobs,
                )
        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise VLLMValidationError(
                "`logprobs` must be a positive value.",
                parameter="logprobs",
                value=logprobs,
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Require `stream=True` whenever `stream_options` is supplied."""
        if data.get("stream_options") and not data.get("stream"):
            raise VLLMValidationError(
                "Stream options can only be defined when `stream=True`.",
                parameter="stream_options",
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt_and_prompt_embeds(cls, data):
        """Require a non-empty `prompt` or a non-empty `prompt_embeds`.

        A prompt counts as empty when it is None or the empty string;
        embeddings count as empty when None or an empty list.

        Raises:
            ValueError: If both are empty.
        """
        prompt = data.get("prompt")
        prompt_embeds = data.get("prompt_embeds")

        prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "")
        embeds_is_empty = prompt_embeds is None or (
            isinstance(prompt_embeds, list) and len(prompt_embeds) == 0
        )

        if prompt_is_empty and embeds_is_empty:
            raise ValueError(
                "Either prompt or prompt_embeds must be provided and non-empty."
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Ensure `cache_salt`, when provided, is a non-empty string.

        Raises:
            ValueError: If `cache_salt` is present (not None) but is not a
                string or is the empty string.
        """
        if data.get("cache_salt") is not None and (
            not isinstance(data["cache_salt"], str) or not data["cache_salt"]
        ):
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data

Zhuohan Li's avatar
Zhuohan Li committed
1379

1380
class CompletionLogProbs(OpenAIBaseModel):
    # Log-probability payload used as the `logprobs` field of the completion
    # choice models. Every field is a list that defaults to empty.
    text_offset: list[int] = Field(default_factory=list)
    token_logprobs: list[float | None] = Field(default_factory=list)
    tokens: list[str] = Field(default_factory=list)
    top_logprobs: list[dict[str, float] | None] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
1385
1386


1387
class CompletionResponseChoice(OpenAIBaseModel):
    # One choice inside a `CompletionResponse`.
    index: int
    text: str
    logprobs: CompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = Field(
        None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    # For response
    token_ids: list[int] | None = None
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    # For prompt
    prompt_token_ids: list[int] | None = None
Zhuohan Li's avatar
Zhuohan Li committed
1403
1404


1405
class CompletionResponse(OpenAIBaseModel):
    # Completion response body. `id` defaults to a fresh "cmpl-"-prefixed
    # random id and `created` to the current Unix time in whole seconds.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: Literal["text_completion"] = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    kv_transfer_params: dict[str, Any] | None = Field(
        None,
        description="KVTransfer parameters.",
    )
Zhuohan Li's avatar
Zhuohan Li committed
1419
1420


1421
class CompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a `CompletionStreamResponse` chunk.
    index: int
    text: str
    logprobs: CompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = Field(
        None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    # not part of the OpenAI spec but for tracing the tokens
    # prompt tokens is put into choice to align with CompletionResponseChoice
    prompt_token_ids: list[int] | None = None
    token_ids: list[int] | None = None
Zhuohan Li's avatar
Zhuohan Li committed
1438
1439


1440
class CompletionStreamResponse(OpenAIBaseModel):
    # Streaming counterpart of `CompletionResponse`; `usage` is optional
    # here and defaults to None.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(None)
1447
1448


1449
1450
1451
1452
1453
1454
class FunctionCall(OpenAIBaseModel):
    # A function name plus its arguments carried as a single string. Used as
    # `ToolCall.function` and `ChatMessage.function_call`.
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    # A complete tool call. `id` is generated by `make_tool_call_id` when the
    # caller does not supply one; "function" is the only allowed `type`.
    id: str = Field(default_factory=make_tool_call_id)
    type: Literal["function"] = "function"
    function: FunctionCall


1460
class DeltaFunctionCall(BaseModel):
    # Function-call fragment in which both parts are optional; used as
    # `DeltaToolCall.function`.
    name: str | None = None
    arguments: str | None = None
1463
1464
1465
1466


# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
    # `index` is the only required field; the rest default to None.
    id: str | None = None
    type: Literal["function"] | None = None
    index: int
    function: DeltaFunctionCall | None = None
1471
1472
1473
1474
1475
1476
1477


class ExtractedToolCallInformation(BaseModel):
    # Bundles a tool-call extraction result: a flag, the calls themselves
    # and any accompanying text content.

    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: list[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: str | None = None
1483
1484


1485
class ChatMessage(OpenAIBaseModel):
    # A full (non-streamed) chat message. Only `role` is required.
    role: str
    content: str | None = None
    refusal: str | None = None
    annotations: OpenAIAnnotation | None = None
    audio: OpenAIChatCompletionAudio | None = None
    function_call: FunctionCall | None = None
    # Defaults to a fresh empty list per instance.
    tool_calls: list[ToolCall] = Field(default_factory=list)

    # vLLM-specific fields that are not in OpenAI spec
    reasoning: str | None = None
    reasoning_content: str | None = None
    """Deprecated: use `reasoning` instead."""

    @model_validator(mode="after")
    def handle_deprecated_reasoning_content(self):
        """Copy reasoning to reasoning_content for backward compatibility."""
        # NOTE(review): the assignment is unconditional, so a value passed
        # only as `reasoning_content` is replaced by `reasoning` (None when
        # `reasoning` was not given). Confirm this is intended.
        self.reasoning_content = self.reasoning
        return self
1504

1505

1506
1507
1508
class ChatCompletionLogProb(OpenAIBaseModel):
    # Log-probability record for a single token.
    token: str
    # -9999.0 is the value used when no logprob is supplied.
    logprob: float = -9999.0
    # Optional list of integers (presumably the token's byte values — confirm).
    bytes: list[int] | None = None
1510
1511
1512


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    # Extends the single-token record with a list of alternative candidates.

    # Workaround: redefine fields name cache so that it's not
    # shared with the super class.
    field_names: ClassVar[set[str] | None] = None
    # Defaults to a fresh empty list per instance.
    top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
1517
1518
1519


class ChatCompletionLogProbs(OpenAIBaseModel):
    # Wrapper around the per-token logprob entries; None when absent.
    content: list[ChatCompletionLogProbsContent] | None = None
1521
1522


1523
class ChatCompletionResponseChoice(OpenAIBaseModel):
    # One choice of a non-streamed chat completion.
    index: int
    message: ChatMessage
    logprobs: ChatCompletionLogProbs | None = None
    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: int | str | None = None

    # not part of the OpenAI spec but is useful for tracing the tokens
    # in agent scenarios
    token_ids: list[int] | None = None
1534
1535


1536
class ChatCompletionResponse(OpenAIBaseModel):
    # Top-level schema of a non-streamed chat completion.
    # `id` defaults to "chatcmpl-" followed by a value from random_uuid().
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    # Defaults to the current Unix time, truncated to whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    # Required here, unlike the streaming variant where usage is optional.
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    prompt_token_ids: list[int] | None = None
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None, description="KVTransfer parameters."
    )
1552
1553


1554
class DeltaMessage(OpenAIBaseModel):
    # Incremental message fragment; every field is optional.
    role: str | None = None
    content: str | None = None
    reasoning: str | None = None
    reasoning_content: str | None = None
    """Deprecated: use `reasoning` instead."""
    # Defaults to a fresh empty list per instance.
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)

    @model_validator(mode="after")
    def handle_deprecated_reasoning_content(self):
        """Copy reasoning to reasoning_content for backward compatibility."""
        # NOTE(review): the assignment is unconditional, so a value passed
        # only as `reasoning_content` is replaced by `reasoning` (None when
        # `reasoning` was not given). Confirm this is intended.
        self.reasoning_content = self.reasoning
        return self

1568

1569
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed chat completion chunk.
    index: int
    delta: DeltaMessage
    logprobs: ChatCompletionLogProbs | None = None
    # Unlike the non-streamed choice, finish_reason defaults to None here.
    finish_reason: str | None = None
    stop_reason: int | str | None = None
    # not part of the OpenAI spec but for tracing the tokens
    token_ids: list[int] | None = None
1577
1578


1579
class ChatCompletionStreamResponse(OpenAIBaseModel):
    # Top-level schema of one streamed chat completion chunk.
    # `id` defaults to "chatcmpl-" followed by a value from random_uuid().
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Defaults to the current Unix time, truncated to whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
    # not part of the OpenAI spec but for tracing the tokens
    prompt_token_ids: list[int] | None = None
1588
1589


1590
1591
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed transcription chunk; reuses DeltaMessage.
    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None
1594
1595
1596
1597
1598
1599
1600
1601


class TranscriptionStreamResponse(OpenAIBaseModel):
    # Top-level schema of one streamed transcription chunk.
    # `id` defaults to "trsc-" followed by a value from random_uuid().
    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    object: Literal["transcription.chunk"] = "transcription.chunk"
    # Defaults to the current Unix time, truncated to whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
1603
1604


1605
1606
class InputTokensDetails(OpenAIBaseModel):
    # Breakdown of input-token accounting; `cached_tokens` is required.
    cached_tokens: int
    # Per-turn lists default to fresh empty lists.
    input_tokens_per_turn: list[int] = Field(default_factory=list)
    cached_tokens_per_turn: list[int] = Field(default_factory=list)
1609
1610
1611


class OutputTokensDetails(OpenAIBaseModel):
    # Breakdown of output-token accounting; all fields have defaults.
    reasoning_tokens: int = 0
    tool_output_tokens: int = 0
    # Per-turn lists default to fresh empty lists.
    output_tokens_per_turn: list[int] = Field(default_factory=list)
    tool_output_tokens_per_turn: list[int] = Field(default_factory=list)
1616
1617
1618
1619
1620
1621
1622
1623


class ResponseUsage(OpenAIBaseModel):
    # Token usage summary attached to ResponsesResponse.usage.
    # All five fields are required.
    input_tokens: int
    input_tokens_details: InputTokensDetails
    output_tokens: int
    output_tokens_details: OutputTokensDetails
    total_tokens: int
1624
1625


1626
1627
1628
1629
1630
1631
def serialize_message(msg):
    """
    Serializes a single message

    Dicts are returned unchanged, objects exposing ``to_dict`` are converted
    through it, and anything else goes through ``model_dump_json()``.
    """
    if isinstance(msg, dict):
        return msg
    if hasattr(msg, "to_dict"):
        return msg.to_dict()
    # fallback to pydantic dump (note: this yields a JSON string, not a dict)
    return msg.model_dump_json()


def serialize_messages(msgs):
    """
    Serializes multiple messages

    A falsy ``msgs`` (``None`` or an empty sequence) yields ``None``.
    """
    if not msgs:
        return None
    return [serialize_message(item) for item in msgs]


1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
class ResponseRawMessageAndToken(OpenAIBaseModel):
    """Class to show the raw message.
    If message / tokens diverge, tokens is the source of truth"""

    message: str
    tokens: list[int]
    # Fixed discriminator value for this message shape.
    type: Literal["raw_message_tokens"] = "raw_message_tokens"


# Type of ResponsesResponse.input_messages / output_messages: either a list of
# chat message params or a list of raw message+token records.
ResponseInputOutputMessage: TypeAlias = (
    list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken]
)


1660
1661
1662
1663
class ResponsesResponse(OpenAIBaseModel):
    # Response object for the Responses API. Instances are normally built
    # through `from_request` below.
    id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
    created_at: int = Field(default_factory=lambda: int(time.time()))
    # error: Optional[ResponseError] = None
    incomplete_details: IncompleteDetails | None = None
    instructions: str | None = None
    metadata: Metadata | None = None
    model: str
    object: Literal["response"] = "response"
    output: list[ResponseOutputItem]
    parallel_tool_calls: bool
    temperature: float
    tool_choice: ToolChoice
    tools: list[Tool]
    top_p: float
    background: bool
    max_output_tokens: int
    max_tool_calls: int | None = None
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"]
    status: ResponseStatus
    text: ResponseTextConfig | None = None
    top_logprobs: int | None = None
    truncation: Literal["auto", "disabled"]
    usage: ResponseUsage | None = None
    user: str | None = None

    # --8<-- [start:responses-response-extra-params]
    # These are populated when enable_response_messages is set to True
    # NOTE: custom serialization is needed
    # see serialize_input_messages and serialize_output_messages
    input_messages: ResponseInputOutputMessage | None = Field(
        default=None,
        description=(
            "If enable_response_messages, we can show raw token input to model."
        ),
    )
    output_messages: ResponseInputOutputMessage | None = Field(
        default=None,
        description=(
            "If enable_response_messages, we can show raw token output of model."
        ),
    )
    # --8<-- [end:responses-response-extra-params]

    # NOTE: openAI harmony doesn't serialize TextContent properly,
    # TODO: this fixes for TextContent, but need to verify for tools etc
    # https://github.com/openai/harmony/issues/78
    @field_serializer("output_messages", when_used="json")
    def serialize_output_messages(self, msgs, _info):
        # Only applied for JSON output; delegates to serialize_messages,
        # which returns None for a falsy `msgs`.
        return serialize_messages(msgs)

    # NOTE: openAI harmony doesn't serialize TextContent properly, this fixes it
    # https://github.com/openai/harmony/issues/78
    @field_serializer("input_messages", when_used="json")
    def serialize_input_messages(self, msgs, _info):
        # Only applied for JSON output; delegates to serialize_messages,
        # which returns None for a falsy `msgs`.
        return serialize_messages(msgs)

    @classmethod
    def from_request(
        cls,
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        model_name: str,
        created_time: int,
        output: list[ResponseOutputItem],
        status: ResponseStatus,
        usage: ResponseUsage | None = None,
        input_messages: ResponseInputOutputMessage | None = None,
        output_messages: ResponseInputOutputMessage | None = None,
    ) -> "ResponsesResponse":
        """Build a response from a request, its sampling params and output.

        Fields are sourced as follows:
        - echoed from `request`: id (from `request.request_id`), instructions,
          metadata, parallel_tool_calls, tool_choice, tools, background,
          max_tool_calls, previous_response_id, prompt, reasoning,
          service_tier, text, truncation, user;
        - taken from `sampling_params`: temperature, top_p,
          max_output_tokens (from `max_tokens`), top_logprobs (from
          `logprobs`);
        - passed through from the arguments: model_name, created_time,
          output, status, usage, input_messages, output_messages.
        """
        incomplete_details: IncompleteDetails | None = None
        # Any "incomplete" status is currently reported with the
        # "max_output_tokens" reason.
        if status == "incomplete":
            incomplete_details = IncompleteDetails(reason="max_output_tokens")
        # TODO: implement the other reason for incomplete_details,
        # which is content_filter
        # incomplete_details = IncompleteDetails(reason='content_filter')
        return cls(
            id=request.request_id,
            created_at=created_time,
            incomplete_details=incomplete_details,
            instructions=request.instructions,
            metadata=request.metadata,
            model=model_name,
            output=output,
            input_messages=input_messages,
            output_messages=output_messages,
            parallel_tool_calls=request.parallel_tool_calls,
            temperature=sampling_params.temperature,
            tool_choice=request.tool_choice,
            tools=request.tools,
            top_p=sampling_params.top_p,
            background=request.background,
            # NOTE(review): the field is declared as a required `int`; this
            # assumes sampling_params.max_tokens is populated — confirm.
            max_output_tokens=sampling_params.max_tokens,
            max_tool_calls=request.max_tool_calls,
            previous_response_id=request.previous_response_id,
            prompt=request.prompt,
            reasoning=request.reasoning,
            service_tier=request.service_tier,
            status=status,
            text=request.text,
            top_logprobs=sampling_params.logprobs,
            truncation=request.truncation,
            user=request.user,
            usage=usage,
        )


1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
class ResponseReasoningPartDoneEvent(OpenAIBaseModel):
    # Locally defined streaming event for a finished reasoning content part.
    # All fields are required, including `type`, which has no default.
    content_index: int
    """The index of the content part that is done."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that is done."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.done"]
    """The type of the event. Always `response.reasoning_part.done`."""


# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
class ResponseReasoningPartAddedEvent(OpenAIBaseModel):
    # Locally defined streaming event for a newly added reasoning content
    # part. All fields are required, including `type`, which has no default.
    content_index: int
    """The index of the content part that was added."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that was added."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.added"]
    """The type of the event. Always `response.reasoning_part.added`."""


1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
# vLLM Streaming Events
# Note: we override the response type with the vLLM ResponsesResponse type
class ResponseCompletedEvent(OpenAIResponseCompletedEvent):
    # Narrows the inherited `response` field to vLLM's ResponsesResponse.
    response: ResponsesResponse  # type: ignore[override]


class ResponseCreatedEvent(OpenAIResponseCreatedEvent):
    # Narrows the inherited `response` field to vLLM's ResponsesResponse.
    response: ResponsesResponse  # type: ignore[override]


class ResponseInProgressEvent(OpenAIResponseInProgressEvent):
    # Narrows the inherited `response` field to vLLM's ResponsesResponse.
    response: ResponsesResponse  # type: ignore[override]


1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
# Union of every event type in this alias: the three vLLM overrides defined
# above, the two locally defined reasoning-part events, and the event classes
# imported from openai.types.responses.
StreamingResponsesResponse: TypeAlias = (
    ResponseCreatedEvent
    | ResponseInProgressEvent
    | ResponseCompletedEvent
    | ResponseOutputItemAddedEvent
    | ResponseOutputItemDoneEvent
    | ResponseContentPartAddedEvent
    | ResponseContentPartDoneEvent
    | ResponseReasoningTextDeltaEvent
    | ResponseReasoningTextDoneEvent
    | ResponseReasoningPartAddedEvent
    | ResponseReasoningPartDoneEvent
    | ResponseCodeInterpreterCallInProgressEvent
    | ResponseCodeInterpreterCallCodeDeltaEvent
    | ResponseWebSearchCallInProgressEvent
    | ResponseWebSearchCallSearchingEvent
    | ResponseWebSearchCallCompletedEvent
    | ResponseCodeInterpreterCallCodeDoneEvent
    | ResponseCodeInterpreterCallInterpretingEvent
    | ResponseCodeInterpreterCallCompletedEvent
    | ResponseMcpCallArgumentsDeltaEvent
    | ResponseMcpCallArgumentsDoneEvent
    | ResponseMcpCallInProgressEvent
    | ResponseMcpCallCompletedEvent
)
1853

1854

1855
class TokenizeCompletionRequest(OpenAIBaseModel):
    # Tokenize request carrying a plain text prompt. `prompt` is the only
    # required field.
    model: str | None = None
    prompt: str

    # Note the default differs from TokenizeChatRequest, where
    # add_special_tokens defaults to False.
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
1872
1873
1874


class TokenizeChatRequest(OpenAIBaseModel):
    # Tokenize request carrying chat messages. `messages` is the only
    # required field; the remaining options control chat-template rendering.
    model: str | None = None
    messages: list[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    tools: list[ChatCompletionToolsParam] | None = Field(
        default=None,
        description=("A list of tools the model may call."),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject input that enables both mutually exclusive prompt options.

        Runs on the raw input before field validation, so only values that
        are explicitly present in `data` are considered; the field default of
        `add_generation_prompt` does not take part in the check.

        Raises:
            ValueError: if both `continue_final_message` and
                `add_generation_prompt` are truthy in `data`.
        """
        # A "before" validator receives the unvalidated input, which is not
        # guaranteed to be a mapping. Hand anything else back unchanged so
        # pydantic reports a regular validation error instead of this method
        # failing with AttributeError on `.get`.
        if not isinstance(data, dict):
            return data
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

1947

1948
# A tokenize request is either the plain-prompt or the chat-messages variant.
TokenizeRequest: TypeAlias = TokenizeCompletionRequest | TokenizeChatRequest
1949
1950
1951
1952
1953


class TokenizeResponse(OpenAIBaseModel):
    # Result of a tokenize request.
    count: int
    max_model_len: int
    tokens: list[int]
    # Optional; see `return_token_strs` on the tokenize request models.
    token_strs: list[str] | None = None
1956
1957
1958


class DetokenizeRequest(OpenAIBaseModel):
    # Request carrying the token ids to convert back to text.
    model: str | None = None
    tokens: list[int]
1961
1962
1963
1964


class DetokenizeResponse(OpenAIBaseModel):
    # Result of a detokenize request: the resulting text.
    prompt: str
1965
1966


1967
1968
class TokenizerInfoResponse(OpenAIBaseModel):
    """
    Response containing tokenizer configuration
    equivalent to tokenizer_config.json
    """

    # extra="allow" lets arbitrary additional keys be carried alongside the
    # one declared field.
    model_config = ConfigDict(extra="allow")
    tokenizer_class: str


1977
class LoadLoRAAdapterRequest(BaseModel):
    # Request to load a LoRA adapter; both fields are required.
    # NOTE: derives from plain BaseModel rather than OpenAIBaseModel.
    lora_name: str
    lora_path: str


1982
class UnloadLoRAAdapterRequest(BaseModel):
    # Request to unload a LoRA adapter, identified by name.
    # NOTE: derives from plain BaseModel rather than OpenAIBaseModel.
    lora_name: str
    # Optional integer id; None when not provided.
    lora_int_id: int | None = Field(default=None)
1985
1986
1987


## Protocols for Audio
1988
# Output formats accepted by the `response_format` field of the audio requests.
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
1989
1990
1991
1992


class TranscriptionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    language: str | None = None
    """The language of the input audio.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    ## TODO (varun) : Support if set to 0, certain thresholds are met !!

    timestamp_granularities: list[Literal["word", "segment"]] = Field(
        alias="timestamp_granularities[]", default=[]
    )
    """The timestamp granularities to populate for this transcription.

    `response_format` must be set `verbose_json` to use timestamp granularities.
    Either or both of these options are supported: `word`, or `segment`. Note:
    There is no additional latency for segment timestamps, but generating word
    timestamps incurs additional latency.
    """

    stream: bool | None = False
    """When set, it will enable output to be streamed in a similar fashion
    as the Chat Completion endpoint.
    """
    # --8<-- [start:transcription-extra-params]
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )
    # --8<-- [end:transcription-extra-params]

    to_language: str | None = None
    """The language of the output audio we transcribe to.

    Please note that this is not currently used by supported models at this
    time, but it is a placeholder for future use, matching translation api.
    """

    # --8<-- [start:transcription-sampling-params]
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    top_p: float | None = None
    """Enables nucleus (top-p) sampling, where tokens are selected from the
    smallest possible set whose cumulative probability exceeds `p`.
    """

    top_k: int | None = None
    """Limits sampling to the `k` most probable tokens at each step."""

    min_p: float | None = None
    """Filters out tokens with a probability lower than `min_p`, ensuring a
    minimum likelihood threshold during sampling.
    """

    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    frequency_penalty: float | None = 0.0
    """The frequency penalty to use for sampling."""

    repetition_penalty: float | None = None
    """The repetition penalty to use for sampling."""

    presence_penalty: float | None = 0.0
    """The presence penalty to use for sampling."""

    max_completion_tokens: int | None = None
    """The maximum number of tokens to generate."""
    # --8<-- [end:transcription-sampling-params]

    # Default sampling parameters for transcription requests.
    # These are the last-resort values: they apply only when neither the
    # request nor `default_sampling_params` provides the parameter.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Build the SamplingParams for this request.

        Each of temperature, top_p, top_k, min_p and repetition_penalty is
        resolved in order: the request value if not None, else
        `default_sampling_params`, else `_DEFAULT_SAMPLING_PARAMS`.

        Args:
            default_max_tokens: Used verbatim as `max_tokens`.
            default_sampling_params: Optional fallback values; treated as an
                empty dict when None.
        """
        # NOTE: `self.max_completion_tokens` is not consulted here; the
        # caller-supplied default is the only source of `max_tokens`.
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        # NOTE(review): `temperature` is declared as a non-optional float
        # defaulting to 0.0, so this fallback looks reachable only if the
        # attribute is set to None outside validation — confirm.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            presence_penalty=self.presence_penalty,
            # Streaming requests get DELTA outputs, others FINAL_ONLY.
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            extra_args=self.vllm_xargs,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )

    @model_validator(mode="before")
    @classmethod
    def validate_transcription_request(cls, data):
        """Check the raw request data before field validation.

        Raises:
            HTTPException: 422 when `file` is a string rather than a file
                object.
            VLLMValidationError: when a stream option is truthy while
                `stream` is falsy.
        """
        if isinstance(data.get("file"), str):
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail="Expected 'file' to be a file-like object, not 'str'.",
            )

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            # Find which specific stream option was set
            # (the first truthy one in `stream_opts` order; the fallback value
            # is only a safeguard since `any(...)` already found one).
            invalid_param = next(
                (so for so in stream_opts if data.get(so, False)),
                "stream_include_usage",
            )
            raise VLLMValidationError(
                "Stream options can only be defined when `stream=True`.",
                parameter=invalid_param,
            )

        return data
2185
2186
2187


# Transcription response objects
2188
2189
2190
2191
2192
class TranscriptionUsageAudio(OpenAIBaseModel):
    # Duration-based usage record for a transcription.
    type: Literal["duration"] = "duration"
    # Whole seconds (declared as int, so fractional durations are not kept).
    seconds: int


2193
2194
2195
class TranscriptionResponse(OpenAIBaseModel):
    # Non-verbose transcription result: text plus a required usage record.
    text: str
    """The transcribed text."""
    usage: TranscriptionUsageAudio
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213


class TranscriptionWord(OpenAIBaseModel):
    # Word-level timestamp entry used by TranscriptionResponseVerbose.words.
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranscriptionSegment(OpenAIBaseModel):
    # Segment-level entry used by TranscriptionResponseVerbose.segments.
    # avg_logprob, compression_ratio and no_speech_prob are optional; the
    # remaining fields are required.
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float | None = None
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float | None = None
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float | None = None
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranscriptionResponseVerbose(OpenAIBaseModel):
    # Verbose transcription result; `segments` and `words` are optional.
    # Note that `duration` is declared as a string, not a number.
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: list[TranscriptionSegment] | None = None
    """Segments of the transcribed text and their corresponding details."""

    words: list[TranscriptionWord] | None = None
    """Extracted words and their corresponding timestamps."""
2267
2268


2269
2270
2271
2272
2273
# Either of the two transcription result shapes defined above.
TranscriptionResponseVariant: TypeAlias = (
    TranscriptionResponse | TranscriptionResponseVerbose
)


2274
2275
class TranslationResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed translation chunk; same fields as
    # TranscriptionResponseStreamChoice.
    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None
2278
2279
2280
2281
2282
2283
2284
2285


class TranslationStreamResponse(OpenAIBaseModel):
    # Top-level schema of one streamed translation chunk.
    # `id` defaults to "trsl-" followed by a value from random_uuid().
    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
    object: Literal["translation.chunk"] = "translation.chunk"
    # Defaults to the current Unix time, truncated to whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranslationResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298


class TranslationRequest(OpenAIBaseModel):
    """Request body for the audio translation endpoint."""

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranslation

    file: UploadFile
    """
    The audio file object (not file name) to translate, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use."""

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    # TODO support additional sampling parameters
    # --8<-- [start:translation-sampling-params]
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """
    # --8<-- [end:translation-sampling-params]

    # --8<-- [start:translation-extra-params]
    language: str | None = None
    """The language of the input audio we translate from.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy.
    """

    to_language: str | None = None
    """The language we translate the input audio to.

    Please note that this is not supported by all models, refer to the specific
    model documentation for more details.
    For instance, Whisper only supports `to_language=en`.
    """

    stream: bool | None = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
    Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False

    max_completion_tokens: int | None = None
    """The maximum number of tokens to generate."""
    # --8<-- [end:translation-extra-params]

    # Default sampling parameters for translation requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Build the `SamplingParams` used to serve this request.

        Args:
            default_max_tokens: Token budget forwarded unchanged as
                `max_tokens`.
            default_sampling_params: Optional fallback values; only the
                "temperature" key is consulted, and only when this request's
                temperature is None.

        Returns:
            Sampling parameters carrying the resolved temperature, the token
            budget, the request seed, and an output kind of DELTA when
            `stream` is truthy or FINAL_ONLY otherwise.
        """
        # NOTE(review): `max_completion_tokens` is not read here; presumably the
        # caller folds it into `default_max_tokens` - confirm against the caller.
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        # Default parameters
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject stream-only options when streaming was not requested.

        Runs on the raw input before field validation and returns it
        unchanged when the check passes.

        Raises:
            VLLMValidationError: If `stream_include_usage` or
                `stream_continuous_usage_stats` is truthy while `stream` is
                falsy or absent. The first offending option, in that order,
                is reported as `parameter`.
        """
        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        # Collect the options that were actually switched on, in declaration
        # order, so the first one can be named in the error.
        enabled_opts = [so for so in stream_opts if data.get(so, False)]
        if enabled_opts and not stream:
            raise VLLMValidationError(
                "Stream options can only be defined when `stream=True`.",
                parameter=enabled_opts[0],
            )

        return data


# Translation response objects
class TranslationResponse(OpenAIBaseModel):
    """Minimal translation response that carries only the translated text."""

    text: str
    """The translated text."""


class TranslationWord(OpenAIBaseModel):
    """A single word together with its start and end timestamps."""

    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranslationSegment(OpenAIBaseModel):
    """One segment of a translation together with its per-segment details."""

    id: int
    """Unique identifier of the segment."""

    avg_logprob: float | None = None
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float | None = None
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float | None = None
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranslationResponseVerbose(OpenAIBaseModel):
    """Verbose translation result: the text plus optional timing details."""

    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The translated text."""

    segments: list[TranslationSegment] | None = None
    """Segments of the translated text and their corresponding details."""

    words: list[TranslationWord] | None = None
    """Extracted words and their corresponding timestamps."""


# Union of the plain and the verbose translation response models.
TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose


####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
    """Request that supplies prompt token ids and sampling parameters directly."""

    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    token_ids: list[int]
    """The token ids to generate text from."""

    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None
    """The processed MM inputs for the model."""

    sampling_params: SamplingParams
    """The sampling parameters for the model."""

    model: str | None = None

    stream: bool | None = False
    stream_options: StreamOptions | None = None
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )


class GenerateResponseChoice(BaseModel):
    """One entry of `GenerateResponse.choices`."""

    # Index of this choice - presumably its position among the response's
    # choices; confirm against the code that builds the response.
    index: int
    # Optional log-probability payload; None when not provided.
    logprobs: ChatCompletionLogProbs | None = None
    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"
    # Optional token ids for this choice - presumably the generated tokens;
    # TODO confirm with the producer.
    token_ids: list[int] | None = None


class GenerateResponse(BaseModel):
    """Response model pairing a request id with its list of choices."""

    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    choices: list[GenerateResponseChoice]

    # Optional list with one entry per position; each entry is either None or
    # a mapping from int keys (presumably token ids - confirm) to `Logprob`.
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )