protocol.py 90.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
6
import json
Zhuohan Li's avatar
Zhuohan Li committed
7
import time
8
from http import HTTPStatus
9
from typing import Annotated, Any, ClassVar, Literal, TypeAlias
Zhuohan Li's avatar
Zhuohan Li committed
10

11
import regex as re
12
import torch
13
from fastapi import HTTPException, UploadFile
14
from openai.types.chat.chat_completion_audio import (
15
16
17
    ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
18
19
20
21
22
from openai.types.responses import (
    ResponseCodeInterpreterCallCodeDeltaEvent,
    ResponseCodeInterpreterCallCodeDoneEvent,
    ResponseCodeInterpreterCallCompletedEvent,
    ResponseCodeInterpreterCallInProgressEvent,
23
24
25
26
27
    ResponseCodeInterpreterCallInterpretingEvent,
    ResponseContentPartAddedEvent,
    ResponseContentPartDoneEvent,
    ResponseFunctionToolCall,
    ResponseInputItemParam,
28
29
30
31
    ResponseMcpCallArgumentsDeltaEvent,
    ResponseMcpCallArgumentsDoneEvent,
    ResponseMcpCallCompletedEvent,
    ResponseMcpCallInProgressEvent,
32
33
34
35
36
37
38
39
40
41
42
    ResponseOutputItem,
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponsePrompt,
    ResponseReasoningTextDeltaEvent,
    ResponseReasoningTextDoneEvent,
    ResponseStatus,
    ResponseWebSearchCallCompletedEvent,
    ResponseWebSearchCallInProgressEvent,
    ResponseWebSearchCallSearchingEvent,
)
43
from openai.types.responses import (
44
45
46
    ResponseCompletedEvent as OpenAIResponseCompletedEvent,
)
from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent
47
from openai.types.responses import (
48
49
    ResponseInProgressEvent as OpenAIResponseInProgressEvent,
)
50
from openai.types.responses.response_reasoning_item import (
51
52
    Content as ResponseReasoningTextContent,
)
53
from openai_harmony import Message as OpenAIHarmonyMessage
54
55
56
57
58

# Backward compatibility for OpenAI client versions
try:  # For older openai versions (< 1.100.0)
    from openai.types.responses import ResponseTextConfig
except ImportError:  # For newer openai versions (>= 1.100.0)
59
    from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig
60

61

62
from openai.types.responses.response import IncompleteDetails, ToolChoice
63
64
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
65
66
67
68
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
69
    ValidationError,
70
    field_serializer,
71
72
    model_validator,
)
Zhuohan Li's avatar
Zhuohan Li committed
73

74
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
75
from vllm.logger import init_logger
76
from vllm.logprobs import Logprob
77
78
79
80
81
82
from vllm.sampling_params import (
    BeamSearchParams,
    RequestOutputKind,
    SamplingParams,
    StructuredOutputsParams,
)
83
84
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
85

86
87
# Module-level logger; used below to warn about unrecognized request fields.
logger = init_logger(__name__)

88
# Integer limits of torch.long; used as the inclusive bounds for the `seed`
# request field further down in this file.
_LONG_INFO = torch.iinfo(torch.long)
89

Zhuohan Li's avatar
Zhuohan Li committed
90

91
class OpenAIBaseModel(BaseModel):
    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Per-class cache of accepted input keys (field names plus aliases).
    field_names: ClassVar[set[str] | None] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Validate ``data`` and warn about keys that match no declared field.

        Args:
            data: Raw input handed to the model; only ``dict`` input is
                inspected for unknown keys.
            handler: The inner validation callable supplied by pydantic.

        Returns:
            Whatever ``handler(data)`` returns.
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Look only in this class's own namespace. A plain ``cls.field_names``
        # lookup would inherit the set cached on a parent model and then
        # report the subclass's own fields as unknown.
        field_names = cls.__dict__.get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        extra_keys = data.keys() - field_names
        if extra_keys:
            logger.warning(
                "The following fields were present in the request but ignored: %s",
                extra_keys,
            )
        return result
121
122


123
class ErrorInfo(OpenAIBaseModel):
    # Error details; this is the type of `ErrorResponse.error`.
    message: str
    type: str
    param: str | None = None
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
128
129


130
131
132
133
class ErrorResponse(OpenAIBaseModel):
    # Envelope holding a single `ErrorInfo` under the `error` key.
    error: ErrorInfo


134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
class VLLMValidationError(ValueError):
    """vLLM-specific validation error for request validation failures.

    Args:
        message: The error message describing the validation failure.
        parameter: Optional parameter name that failed validation.
        value: Optional value that was rejected during validation.
    """

    def __init__(
        self,
        message: str,
        *,
        parameter: str | None = None,
        value: Any = None,
    ) -> None:
        super().__init__(message)
        self.parameter = parameter
        self.value = value

    def __str__(self):
        base = super().__str__()
        extras = []
        if self.parameter is not None:
            extras.append(f"parameter={self.parameter}")
        if self.value is not None:
            extras.append(f"value={self.value}")
        return f"{base} ({', '.join(extras)})" if extras else base


164
class ModelPermission(OpenAIBaseModel):
    # Permission entry; `ModelCard.permission` holds a list of these.
    # `id` and `created` are generated per instance via default factories.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: str | None = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
177
178


179
class ModelCard(OpenAIBaseModel):
    # Description of one model; `ModelList.data` holds a list of these.
    id: str
    object: str = "model"
    # Creation timestamp defaults to the current time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: str | None = None
    parent: str | None = None
    max_model_len: int | None = None
    permission: list[ModelPermission] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
188
189


190
class ModelList(OpenAIBaseModel):
    # List wrapper around `ModelCard` entries.
    object: str = "list"
    data: list[ModelCard] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
193
194


195
class PromptTokenUsageInfo(OpenAIBaseModel):
    # Prompt-side usage breakdown; type of `UsageInfo.prompt_tokens_details`.
    cached_tokens: int | None = None
197
198


199
class UsageInfo(OpenAIBaseModel):
    # Token counts for a request; every counter defaults to 0.
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: int | None = 0
    prompt_tokens_details: PromptTokenUsageInfo | None = None
Zhuohan Li's avatar
Zhuohan Li committed
204
205


206
207
class RequestResponseMetadata(BaseModel):
    # Derives from plain `BaseModel` (not `OpenAIBaseModel`), so it gets
    # neither the extra-field allowance nor the extra-field warning.
    request_id: str
    final_usage_info: UsageInfo | None = None
209
210


211
212
class JsonSchemaResponseFormat(OpenAIBaseModel):
    name: str
    description: str | None = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    strict: bool | None = None
218
219


220
class LegacyStructuralTag(OpenAIBaseModel):
    begin: str
    # schema is the field, but that causes conflicts with pydantic so
    # instead use structural_tag_schema with an alias
    structural_tag_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    end: str


228
class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
    # Legacy shape: a list of `LegacyStructuralTag` structures plus triggers.
    type: Literal["structural_tag"]
    structures: list[LegacyStructuralTag]
    triggers: list[str]


234
235
236
237
238
239
240
241
242
243
class StructuralTagResponseFormat(OpenAIBaseModel):
    # Shares the "structural_tag" type tag with the legacy variant but carries
    # a single `format` payload, which is not validated here (typed `Any`).
    type: Literal["structural_tag"]
    format: Any


# Either structural-tag response format (legacy or current).
AnyStructuralTagResponseFormat: TypeAlias = (
    LegacyStructuralTagResponseFormat | StructuralTagResponseFormat
)


244
class ResponseFormat(OpenAIBaseModel):
    # type must be "json_schema", "json_object", or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: JsonSchemaResponseFormat | None = None
248
249


250
251
252
# Every response-format shape accepted by `ChatCompletionRequest.response_format`.
AnyResponseFormat: TypeAlias = (
    ResponseFormat | StructuralTagResponseFormat | LegacyStructuralTagResponseFormat
)
253
254


255
class StreamOptions(OpenAIBaseModel):
    # Streaming options; type of `ChatCompletionRequest.stream_options`.
    include_usage: bool | None = True
    continuous_usage_stats: bool | None = False
258
259


260
261
class FunctionDefinition(OpenAIBaseModel):
    # Function description; type of `ChatCompletionToolsParam.function`.
    name: str
    description: str | None = None
    parameters: dict[str, Any] | None = None
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279


class ChatCompletionToolsParam(OpenAIBaseModel):
    # A tool entry; "function" is the only accepted tool type.
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    # Names a function; type of `ChatCompletionNamedToolChoiceParam.function`.
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    # Object form of `tool_choice`: refers to one function by name.
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


280
281
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
class LogitsProcessorConstructor(BaseModel):
    # Qualified name of the processor class/factory, with optional positional
    # and keyword arguments; `get_logits_processors` resolves and calls it.
    qualname: str
    args: list[Any] | None = None
    kwargs: dict[str, Any] | None = None

    model_config = ConfigDict(extra="forbid")

289

290
# Each entry is either a qualified name or a constructor spec with arguments.
LogitsProcessors = list[str | LogitsProcessorConstructor]
291
292


293
def get_logits_processors(
    processors: LogitsProcessors | None, pattern: str | None
) -> list[Any] | None:
    """Resolve requested logits processors, subject to an allow-list pattern.

    Args:
        processors: Qualified names, or ``LogitsProcessorConstructor`` objects
            whose resolved target is called with the given args/kwargs.
        pattern: Regular expression every qualified name must match. It is
            applied with ``re.match``, i.e. anchored at the start only.

    Returns:
        The resolved (and, for constructor entries, instantiated) processors,
        or ``None`` when no processors were requested.

    Raises:
        ValueError: If processors are requested but no pattern is configured,
            if a name does not match the pattern, or if a name cannot be
            resolved.
    """
    if not processors:
        return None

    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information."
        )

    logits_processors = []
    for processor in processors:
        qualname = processor if isinstance(processor, str) else processor.qualname
        if not re.match(pattern, qualname):
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information."
            )
        try:
            logits_processor = resolve_obj_by_qualname(qualname)
        except Exception as e:
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {e}"
            ) from e
        if isinstance(processor, LogitsProcessorConstructor):
            # Constructor entries name a class/factory: call it to build the
            # actual processor. Plain string entries are used as resolved.
            logits_processor = logits_processor(
                *processor.args or [], **processor.kwargs or {}
            )
        logits_processors.append(logits_processor)
    return logits_processors


327
# Element type of a list-valued `ResponsesRequest.input`: either an input
# item param or an output item.
ResponseInputOutputItem: TypeAlias = ResponseInputItemParam | ResponseOutputItem
328
329


330
331
332
class ResponsesRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/responses/create
    background: bool | None = False
    include: (
        list[
            Literal[
                "code_interpreter_call.outputs",
                "computer_call_output.output.image_url",
                "file_search_call.results",
                "message.input_image.image_url",
                "message.output_text.logprobs",
                "reasoning.encrypted_content",
            ],
        ]
        | None
    ) = None
    input: str | list[ResponseInputOutputItem]
    instructions: str | None = None
    max_output_tokens: int | None = None
    max_tool_calls: int | None = None
    metadata: Metadata | None = None
    model: str | None = None
    logit_bias: dict[str, float] | None = None
    parallel_tool_calls: bool | None = True
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto"
    store: bool | None = True
    stream: bool | None = False
    temperature: float | None = None
    text: ResponseTextConfig | None = None
    tool_choice: ToolChoice = "auto"
    tools: list[Tool] = Field(default_factory=list)
    top_logprobs: int | None = 0
    top_p: float | None = None
    top_k: int | None = None
    truncation: Literal["auto", "disabled"] | None = "disabled"
    user: str | None = None

    # --8<-- [start:responses-extra-params]
    request_id: str = Field(
        default_factory=lambda: f"resp_{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    enable_response_messages: bool = Field(
        default=False,
        description=(
            "Dictates whether or not to return messages as part of the "
            # The trailing space keeps the two fragments from running together.
            "response object. Currently only supported for "
            "non-background and gpt-oss only. "
        ),
    )
    # similar to input_messages / output_messages in ResponsesResponse
    # we take in previous_input_messages (ie in harmony format)
    # this cannot be used in conjunction with previous_response_id
    # TODO: consider supporting non harmony messages as well
    previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
    # --8<-- [end:responses-extra-params]

    # Last-resort sampling values, used when neither the request nor the
    # caller-supplied defaults provide one.
    _DEFAULT_SAMPLING_PARAMS = {
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
    }

    def to_sampling_params(
        self,
        default_max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Build the ``SamplingParams`` for this request.

        Args:
            default_max_tokens: Upper bound on generated tokens; also the
                value used when ``max_output_tokens`` is unset.
            default_sampling_params: Optional fallback values (``temperature``,
                ``top_p``, ``top_k``, ``stop_token_ids``) consulted when the
                request leaves the corresponding field unset.

        Returns:
            Sampling parameters derived from the request fields.

        Raises:
            NotImplementedError: If ``text.format`` has type ``json_object``.
        """
        if self.max_output_tokens is None:
            max_tokens = default_max_tokens
        else:
            # A request may lower the limit but never raise it.
            max_tokens = min(self.max_output_tokens, default_max_tokens)

        # Resolution order: request value, caller defaults, class defaults.
        default_sampling_params = default_sampling_params or {}
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        stop_token_ids = default_sampling_params.get("stop_token_ids")

        # Structured output
        structured_outputs = None
        if self.text is not None and self.text.format is not None:
            response_format = self.text.format
            if (
                response_format.type == "json_schema"
                and response_format.schema_ is not None
            ):
                structured_outputs = StructuredOutputsParams(
                    json=response_format.schema_
                )
            elif response_format.type == "json_object":
                raise NotImplementedError("json_object is not supported")

        # TODO: add more parameters
        return SamplingParams.from_optional(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_tokens=max_tokens,
            logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
            stop_token_ids=stop_token_ids,
            output_kind=(
                RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
            ),
            structured_outputs=structured_outputs,
            logit_bias=self.logit_bias,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )

480
481
482
483
    def is_include_output_logprobs(self) -> bool:
        """Check if the request includes output logprobs.

        Returns:
            ``True`` only when ``include`` is a list containing
            ``"message.output_text.logprobs"``.
        """
        if self.include is None:
            return False
        return (
            isinstance(self.include, list)
            and "message.output_text.logprobs" in self.include
        )
488

489
490
491
492
493
    @model_validator(mode="before")
    @classmethod
    def validate_background(cls, data):
        """Reject ``background`` requests that disable ``store``.

        Raises:
            ValueError: If ``background`` is truthy while ``store`` is falsy.
        """
        # "before" validators receive the raw input, which is not guaranteed
        # to be a dict; leave anything else for pydantic to handle.
        if not isinstance(data, dict):
            return data
        if not data.get("background"):
            return data
        if not data.get("store", True):
            raise ValueError("background can only be used when `store` is true")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt(cls, data):
        """Reject requests that set ``prompt``.

        Raises:
            VLLMValidationError: If ``prompt`` is present and not ``None``.
        """
        # "before" validators receive the raw input, which is not guaranteed
        # to be a dict; leave anything else for pydantic to handle.
        if not isinstance(data, dict):
            return data
        if data.get("prompt") is not None:
            raise VLLMValidationError(
                "prompt template is not supported", parameter="prompt"
            )
        return data

505
506
    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Require ``cache_salt``, when given, to be a non-empty string.

        Raises:
            ValueError: If ``cache_salt`` is not ``None`` and is either not a
                string or an empty string.
        """
        # "before" validators receive the raw input, which is not guaranteed
        # to be a dict; leave anything else for pydantic to handle.
        if not isinstance(data, dict):
            return data
        cache_salt = data.get("cache_salt")
        if cache_salt is not None and (
            not isinstance(cache_salt, str) or not cache_salt
        ):
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data

515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
    @model_validator(mode="before")
    @classmethod
    def function_call_parsing(cls, data):
        """Parse function_call dictionaries into ResponseFunctionToolCall objects.
        This ensures Pydantic can properly resolve union types in the input field.
        Function calls provided as dicts are converted to ResponseFunctionToolCall
        objects before validation, while invalid structures are left for Pydantic
        to reject with appropriate error messages.

        Note: ``data["input"]`` is replaced in place with the processed list.
        """
        # "before" validators receive the raw input, which is not guaranteed
        # to be a dict; leave anything else for pydantic to handle.
        if not isinstance(data, dict):
            return data

        input_data = data.get("input")

        # Early return for None, strings, or bytes
        # (strings are iterable but shouldn't be processed)
        if input_data is None or isinstance(input_data, (str, bytes)):
            return data

        # Convert iterators (like ValidatorIterator) to list
        if not isinstance(input_data, list):
            try:
                input_data = list(input_data)
            except TypeError:
                # Not iterable, leave as-is for Pydantic to handle
                return data

        processed_input = []
        for item in input_data:
            if isinstance(item, dict) and item.get("type") == "function_call":
                try:
                    processed_input.append(ResponseFunctionToolCall(**item))
                except ValidationError:
                    # Let Pydantic handle validation for malformed function calls
                    logger.debug(
                        "Failed to parse function_call to ResponseFunctionToolCall, "
                        "leaving for Pydantic validation"
                    )
                    processed_input.append(item)
            else:
                processed_input.append(item)

        data["input"] = processed_input
        return data

557

558
class ChatCompletionRequest(OpenAIBaseModel):
559
560
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
561
    messages: list[ChatCompletionMessageParam]
562
563
564
565
566
567
    model: str | None = None
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: bool | None = False
    top_logprobs: int | None = 0
    max_tokens: int | None = Field(
568
        default=None,
569
570
        deprecated="max_tokens is deprecated in favor of "
        "the max_completion_tokens field",
571
    )
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
    max_completion_tokens: int | None = None
    n: int | None = 1
    presence_penalty: float | None = 0.0
    response_format: AnyResponseFormat | None = None
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    temperature: float | None = None
    top_p: float | None = None
    tools: list[ChatCompletionToolsParam] | None = None
    tool_choice: (
        Literal["none"]
        | Literal["auto"]
        | Literal["required"]
        | ChatCompletionNamedToolChoiceParam
        | None
    ) = "none"
    reasoning_effort: Literal["low", "medium", "high"] | None = None
591
    include_reasoning: bool = True
592
    parallel_tool_calls: bool | None = True
593

594
    # NOTE this will be ignored by vLLM
595
    user: str | None = None
596

597
    # --8<-- [start:chat-completion-sampling-params]
598
    use_beam_search: bool = False
599
600
601
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
602
    length_penalty: float = 1.0
603
    stop_token_ids: list[int] | None = []
604
605
606
607
608
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
609
610
611
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    prompt_logprobs: int | None = None
    allowed_token_ids: list[int] | None = None
612
    bad_words: list[str] = Field(default_factory=list)
613
    # --8<-- [end:chat-completion-sampling-params]
614

615
    # --8<-- [start:chat-completion-extra-params]
616
    echo: bool = Field(
617
618
619
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
620
621
            "if they belong to the same role."
        ),
622
    )
623
    add_generation_prompt: bool = Field(
624
        default=True,
625
626
627
628
629
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
630
    )
631
632
    continue_final_message: bool = Field(
        default=False,
633
634
635
636
637
638
639
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
640
    )
641
    add_special_tokens: bool = Field(
642
643
644
645
646
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
647
            "special tokens so this should be set to false (as is the "
648
649
            "default)."
        ),
650
    )
651
    documents: list[dict[str, str]] | None = Field(
652
        default=None,
653
654
655
656
657
658
659
        description=(
            "A list of dicts representing documents that will be accessible to "
            "the model if it is performing RAG (retrieval-augmented generation)."
            " If the template does not support RAG, this argument will have no "
            "effect. We recommend that each document should be a dict containing "
            '"title" and "text" keys.'
        ),
660
    )
661
    chat_template: str | None = Field(
662
663
664
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
665
666
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
667
668
            "does not define one."
        ),
669
    )
670
    chat_template_kwargs: dict[str, Any] | None = Field(
671
        default=None,
672
673
        description=(
            "Additional keyword args to pass to the template renderer. "
674
675
            "Will be accessible by the chat template."
        ),
676
    )
677
    mm_processor_kwargs: dict[str, Any] | None = Field(
678
679
680
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
681
    structured_outputs: StructuredOutputsParams | None = Field(
682
        default=None,
683
        description="Additional kwargs for structured outputs",
684
    )
685
686
687
688
689
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
690
691
            "if the served model does not use priority scheduling."
        ),
692
    )
693
    request_id: str = Field(
694
        default_factory=random_uuid,
695
696
697
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
698
699
            "through out the inference process and return in response."
        ),
700
    )
701
    logits_processors: LogitsProcessors | None = Field(
702
703
704
705
706
707
708
709
710
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
711
712
713
            "{'param': 'value'}}."
        ),
    )
714
    return_tokens_as_token_ids: bool | None = Field(
715
716
717
718
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
719
720
721
            "that are not JSON-encodable can be identified."
        ),
    )
722
    return_token_ids: bool | None = Field(
723
724
725
726
727
728
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
729
730
731
            "need to map generated text back to input tokens."
        ),
    )
732
    cache_salt: str | None = Field(
733
734
735
736
737
738
739
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
740
            "to 256 bit)."
741
742
        ),
    )
743
    kv_transfer_params: dict[str, Any] | None = Field(
Robert Shaw's avatar
Robert Shaw committed
744
        default=None,
745
746
        description="KVTransfer parameters used for disaggregated serving.",
    )
747

748
    vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
749
        default=None,
750
        description=(
751
            "Additional request parameters with (list of) string or "
752
753
            "numeric values, used by custom extensions."
        ),
754
755
    )

756
    # --8<-- [end:chat-completion-extra-params]
Zhuohan Li's avatar
Zhuohan Li committed
757

758
759
760
761
762
    # Default sampling parameters for chat completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
763
        "top_k": 0,
764
765
766
767
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self, max_tokens: int, default_sampling_params: dict
    ) -> BeamSearchParams:
        """Build the ``BeamSearchParams`` for this request.

        Args:
            max_tokens: Token limit passed through to the beam search.
            default_sampling_params: Fallback values; only ``temperature`` is
                consulted, and only when the request does not set it.

        Returns:
            Beam-search parameters whose beam width is ``n`` (1 when unset).
        """
        # Tolerate a missing defaults mapping, as
        # ResponsesRequest.to_sampling_params does.
        default_sampling_params = default_sampling_params or {}

        n = self.n if self.n is not None else 1
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )
784

785
    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict,
    ) -> SamplingParams:
        """Build the `SamplingParams` for this chat completion request.

        Sampling knobs left unset on the request (`repetition_penalty`,
        `temperature`, `top_p`, `top_k`, `min_p`) fall back first to
        `default_sampling_params` and then to `_DEFAULT_SAMPLING_PARAMS`.

        Side effect: when `response_format` is set, `self.structured_outputs`
        is created if missing and updated to reflect the requested format.
        `self.vllm_xargs` is never modified.

        Args:
            max_tokens: Token budget, forwarded unchanged.
            logits_processor_pattern: Passed to `get_logits_processors`
                together with the request's `logits_processors`.
            default_sampling_params: Defaults keyed by parameter name.

        Returns:
            The object produced by `SamplingParams.from_optional`.
        """

        def _resolve(name: str, requested: Any) -> Any:
            # Request value wins; otherwise provided default, then fallback.
            if requested is not None:
                return requested
            return default_sampling_params.get(
                name, self._DEFAULT_SAMPLING_PARAMS[name]
            )

        # Default parameters
        repetition_penalty = _resolve("repetition_penalty", self.repetition_penalty)
        temperature = _resolve("temperature", self.temperature)
        top_p = _resolve("top_p", self.top_p)
        top_k = _resolve("top_k", self.top_k)
        min_p = _resolve("min_p", self.min_p)

        # With `echo`, prompt logprobs default to the `top_logprobs` count.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        response_format = self.response_format
        if response_format is not None:
            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            # Set structured output params for response format
            if response_format.type == "json_object":
                self.structured_outputs.json_object = True
            elif response_format.type == "json_schema":
                json_schema = response_format.json_schema
                assert json_schema is not None
                self.structured_outputs.json = json_schema.json_schema
            elif response_format.type == "structural_tag":
                structural_tag = response_format
                assert structural_tag is not None and isinstance(
                    structural_tag,
                    (
                        LegacyStructuralTagResponseFormat,
                        StructuralTagResponseFormat,
                    ),
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

        # Work on a copy: inserting `kv_transfer_params` below must not leak
        # into the request's own `vllm_xargs` field.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params

        return SamplingParams.from_optional(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            bad_words=self.bad_words,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )
882

883
    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` on a request that does not stream.

        Returns `data` unchanged when the combination is acceptable; raises
        `VLLMValidationError` otherwise.
        """
        # Acceptable: no (truthy) stream options, or streaming is enabled.
        if not data.get("stream_options") or data.get("stream"):
            return data

        raise VLLMValidationError(
            "Stream options can only be defined when `stream=True`.",
            parameter="stream_options",
        )

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `top_logprobs` on the raw input.

        Raises `VLLMValidationError` when:
        * `prompt_logprobs` is requested (positive or -1) while streaming;
        * `prompt_logprobs` or `top_logprobs` is negative but not -1;
        * `top_logprobs` is requested (positive or -1) without `logprobs`.

        Returns `data` unchanged otherwise.
        """
        prompt_logprobs = data.get("prompt_logprobs")
        if prompt_logprobs is not None:
            if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
                raise VLLMValidationError(
                    "`prompt_logprobs` are not available when `stream=True`.",
                    parameter="prompt_logprobs",
                )

            if prompt_logprobs < 0 and prompt_logprobs != -1:
                raise VLLMValidationError(
                    "`prompt_logprobs` must be a positive value or -1.",
                    parameter="prompt_logprobs",
                    value=prompt_logprobs,
                )

        top_logprobs = data.get("top_logprobs")
        if top_logprobs is None:
            return data

        if top_logprobs < 0 and top_logprobs != -1:
            raise VLLMValidationError(
                "`top_logprobs` must be a positive value or -1.",
                parameter="top_logprobs",
                value=top_logprobs,
            )

        if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"):
            raise VLLMValidationError(
                "when using `top_logprobs`, `logprobs` must be set to true.",
                parameter="top_logprobs",
            )

        return data
925

926
927
    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Allow at most one of `json`, `regex`, `choice` in `structured_outputs`.

        A `ValueError` instance passed in as `data` is re-raised as is.
        Returns `data` unchanged when `structured_outputs` is absent/None or
        when at most one of the three constraint keys is non-None.
        """
        if isinstance(data, ValueError):
            raise data

        structured_outputs_kwargs = data.get("structured_outputs", None)
        if structured_outputs_kwargs is None:
            return data

        provided = [
            key
            for key in ("json", "regex", "choice")
            if structured_outputs_kwargs.get(key) is not None
        ]
        count = len(provided)

        # you can only use one kind of constraints for structured outputs
        if count > 1:
            raise ValueError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice')."
            )

        # you can only either use structured outputs or tools, not both
        # NOTE(review): this branch cannot be reached as written, because
        # `count > 1` has already raised just above. Confirm the intended
        # condition before changing it.
        if count > 1 and data.get("tool_choice", "none") not in (
            "none",
            "auto",
            "required",
        ):
            raise ValueError(
                "You can only either use constraints for structured outputs "
                "or tools, not both."
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        """Normalise and validate `tool_choice` against `tools`.

        * Tools without a `tool_choice` key get `tool_choice="auto"`.
        * `tool_choice` of "none" or None needs no further validation.
        * Any other `tool_choice` requires `tools` to be set, and must be
          "auto", "required" or a dict naming a function.
        * "required" with an empty tools list is rewritten to "none" and the
          `tools` key is removed.
        * A dict `tool_choice` must name a function present in `tools`.

        Raises `ValueError` on any violation; returns `data` otherwise.
        """
        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        tool_choice = data.get("tool_choice")

        # No tool choice, or "none": nothing to validate.
        if tool_choice is None or tool_choice == "none":
            return data

        # ensure that if "tool choice" is specified, tools are present
        if data.get("tools") is None:
            raise ValueError("When using `tool_choice`, `tools` must be set.")
        tools = data["tools"]

        # tool choice is either a named tool (dict) or "auto" / "required"
        if not isinstance(tool_choice, dict) and tool_choice not in (
            "auto",
            "required",
        ):
            raise ValueError(
                f"Invalid value for `tool_choice`: {tool_choice}! "
                'Only named tools, "none", "auto" or "required" '
                "are supported."
            )

        # if tool_choice is "required" but the "tools" list is empty,
        # override the data to behave like "none" to align with
        # OpenAI’s behavior.
        if tool_choice == "required" and isinstance(tools, list) and not tools:
            data["tool_choice"] = "none"
            del data["tools"]
            return data

        # Only a named tool choice needs matching against the tool list.
        if not isinstance(tool_choice, dict):
            return data

        usage_hint = (
            'Correct usage: `{"type": "function",'
            ' "function": {"name": "my_function"}}`'
        )

        function = tool_choice.get("function")
        if not isinstance(function, dict):
            raise ValueError(
                f"Invalid value for `function`: `{function}` in "
                f"`tool_choice`! {usage_hint}"
            )
        if "name" not in function:
            raise ValueError(
                f"Expected field `name` in `function` in "
                f"`tool_choice`! {usage_hint}"
            )

        function_name = function["name"]
        if not isinstance(function_name, str) or len(function_name) == 0:
            raise ValueError(
                f"Invalid `name` in `function`: `{function_name}`"
                f" in `tool_choice`! {usage_hint}"
            )

        # The named function has to be one of the declared tools.
        if not any(tool["function"]["name"] == function_name for tool in tools):
            raise ValueError(
                "The tool specified in `tool_choice` does not match any"
                " of the specified `tools`"
            )

        return data

1035
1036
1037
    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Forbid enabling `continue_final_message` and `add_generation_prompt` together.

        Raises `ValueError` when both are truthy; returns `data` otherwise.
        """
        continuing = data.get("continue_final_message")
        if continuing and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )

        return data

1045
1046
1047
    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Require `cache_salt`, when given, to be a non-empty string.

        Raises `ValueError` for a non-string or empty value; returns `data`
        unchanged when the salt is absent, None, or a non-empty string.
        """
        salt = data.get("cache_salt")
        if salt is None:
            return data
        if isinstance(salt, str) and salt:
            return data

        raise ValueError(
            "Parameter 'cache_salt' must be a non-empty string if provided."
        )

Zhuohan Li's avatar
Zhuohan Li committed
1056

1057
class CompletionRequest(OpenAIBaseModel):
    # Request schema following the OpenAI completions `create` parameters,
    # followed by vLLM sampling parameters and extra parameters. The
    # `--8<--` markers delimit the latter two groups and must be kept as is.

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str | None = None
    prompt: list[int] | list[list[int]] | str | list[str] | None = None
    echo: bool | None = False
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: int | None = None
    max_tokens: int | None = 16
    n: int = 1
    presence_penalty: float | None = 0.0
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    suffix: str | None = None
    temperature: float | None = None
    top_p: float | None = None
    user: str | None = None

    # --8<-- [start:completion-sampling-params]
    use_beam_search: bool = False
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
    length_penalty: float = 1.0
    stop_token_ids: list[int] | None = []
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
    allowed_token_ids: list[int] | None = None
    prompt_logprobs: int | None = None
    # --8<-- [end:completion-sampling-params]

    # --8<-- [start:completion-extra-params]
    prompt_embeds: bytes | list[bytes] | None = None
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    response_format: AnyResponseFormat | None = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
            ", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
        ),
    )
    structured_outputs: StructuredOutputsParams | None = Field(
        default=None,
        description="Additional kwargs for structured outputs",
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    logits_processors: LogitsProcessors | None = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."
        ),
    )

    return_tokens_as_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."
        ),
    )
    return_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."
        ),
    )

    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )

    # --8<-- [end:completion-extra-params]

    # Default sampling parameters for completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self,
        max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> BeamSearchParams:
        """Translate this request into beam-search parameters.

        Args:
            max_tokens: Token budget, forwarded unchanged.
            default_sampling_params: Optional defaults consulted when the
                request leaves `temperature` unset; None means no defaults.

        Returns:
            A `BeamSearchParams` whose beam width is the request's `n`
            (1 when `n` is None).
        """
        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        if (temperature := self.temperature) is None:
            # Same fallback chain as `to_sampling_params`: provided default,
            # then the class-level value.
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Build the `SamplingParams` for this completion request.

        Sampling knobs left unset on the request (`repetition_penalty`,
        `temperature`, `top_p`, `top_k`, `min_p`) fall back first to
        `default_sampling_params` and then to `_DEFAULT_SAMPLING_PARAMS`.

        Side effect: when `response_format` is set, `self.structured_outputs`
        is created if missing and updated to reflect the requested format.
        `self.vllm_xargs` is never modified.

        Args:
            max_tokens: Token budget; replaced by 1 when the request asks for
                `echo` with `max_tokens == 0`.
            logits_processor_pattern: Passed to `get_logits_processors`
                together with the request's `logits_processors`.
            default_sampling_params: Optional defaults keyed by parameter
                name; None means no defaults.

        Returns:
            The object produced by `SamplingParams.from_optional`.
        """
        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        # With `echo`, prompt logprobs default to the `logprobs` count.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        echo_without_generation = self.echo and self.max_tokens == 0

        response_format = self.response_format
        if response_format is not None:
            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            # Set structured output params for response format
            if response_format.type == "json_object":
                self.structured_outputs.json_object = True
            elif response_format.type == "json_schema":
                json_schema = response_format.json_schema
                assert json_schema is not None
                self.structured_outputs.json = json_schema.json_schema
            elif response_format.type == "structural_tag":
                structural_tag = response_format
                assert structural_tag is not None and isinstance(
                    structural_tag,
                    (
                        LegacyStructuralTagResponseFormat,
                        StructuralTagResponseFormat,
                    ),
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

        # Work on a copy: inserting `kv_transfer_params` below must not leak
        # into the request's own `vllm_xargs` field.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )

    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Allow at most one of `json`, `regex`, `choice` in `structured_outputs`.

        Raises `VLLMValidationError` when more than one is non-None; returns
        `data` unchanged otherwise.
        """
        if data.get("structured_outputs", None) is None:
            return data

        structured_outputs_kwargs = data["structured_outputs"]
        count = sum(
            structured_outputs_kwargs.get(k) is not None
            for k in ("json", "regex", "choice")
        )
        if count > 1:
            raise VLLMValidationError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice').",
                parameter="structured_outputs",
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `logprobs` on the raw input.

        Raises `VLLMValidationError` when `prompt_logprobs` is requested
        (positive or -1) while streaming, when it is negative but not -1, or
        when `logprobs` is negative. Returns `data` unchanged otherwise.
        """
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
                raise VLLMValidationError(
                    "`prompt_logprobs` are not available when `stream=True`.",
                    parameter="prompt_logprobs",
                )

            if prompt_logprobs < 0 and prompt_logprobs != -1:
                raise VLLMValidationError(
                    "`prompt_logprobs` must be a positive value or -1.",
                    parameter="prompt_logprobs",
                    value=prompt_logprobs,
                )
        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise VLLMValidationError(
                "`logprobs` must be a positive value.",
                parameter="logprobs",
                value=logprobs,
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` on a request that does not stream."""
        if data.get("stream_options") and not data.get("stream"):
            raise VLLMValidationError(
                "Stream options can only be defined when `stream=True`.",
                parameter="stream_options",
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt_and_prompt_embeds(cls, data):
        """Require a non-empty `prompt` or a non-empty `prompt_embeds`.

        A prompt counts as empty when it is None or the empty string;
        embeds count as empty when None or an empty list. Raises
        `ValueError` when both are empty.
        """
        prompt = data.get("prompt")
        prompt_embeds = data.get("prompt_embeds")

        prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "")
        embeds_is_empty = prompt_embeds is None or (
            isinstance(prompt_embeds, list) and len(prompt_embeds) == 0
        )

        if prompt_is_empty and embeds_is_empty:
            raise ValueError(
                "Either prompt or prompt_embeds must be provided and non-empty."
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Require `cache_salt`, when given, to be a non-empty string."""
        if data.get("cache_salt") is not None and (
            not isinstance(data["cache_salt"], str) or not data["cache_salt"]
        ):
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data

Zhuohan Li's avatar
Zhuohan Li committed
1408

1409
class CompletionLogProbs(OpenAIBaseModel):
    # Four list-valued fields; each instance starts with its own empty list
    # for every one of them.
    text_offset: list[int] = Field(default_factory=list)
    token_logprobs: list[float | None] = Field(default_factory=list)
    tokens: list[str] = Field(default_factory=list)
    top_logprobs: list[dict[str, float] | None] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
1414
1415


1416
class CompletionResponseChoice(OpenAIBaseModel):
    # One choice of a non-streaming completion response.
    index: int
    text: str
    logprobs: CompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = Field(
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
        default=None,
    )
    # Token ids of the generated text.
    token_ids: list[int] | None = None  # For response
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    # Token ids of the prompt.
    prompt_token_ids: list[int] | None = None  # For prompt
Zhuohan Li's avatar
Zhuohan Li committed
1432
1433


1434
class CompletionResponse(OpenAIBaseModel):
    # Non-streaming completion response. `id` defaults to "cmpl-" followed
    # by a fresh random uuid and `created` to the current Unix time in
    # whole seconds.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: Literal["text_completion"] = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    kv_transfer_params: dict[str, Any] | None = Field(
        description="KVTransfer parameters.", default=None
    )
Zhuohan Li's avatar
Zhuohan Li committed
1448
1449


1450
class CompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed completion chunk.
    index: int
    text: str
    logprobs: CompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = Field(
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
        default=None,
    )
    # Not part of the OpenAI spec; carried for tracing tokens. The prompt
    # token ids live on the choice to mirror CompletionResponseChoice.
    prompt_token_ids: list[int] | None = None
    token_ids: list[int] | None = None
Zhuohan Li's avatar
Zhuohan Li committed
1467
1468


1469
class CompletionStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed completion. `id` defaults to "cmpl-" followed
    # by a fresh random uuid and `created` to the current Unix time in
    # whole seconds; `usage` is optional.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
1476
1477


1478
1479
1480
1481
1482
1483
class FunctionCall(OpenAIBaseModel):
    # A function invocation: the function's name and its arguments as a
    # string. Both fields are required.
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    # A complete tool call. `id` is generated by `make_tool_call_id` when
    # not supplied; "function" is the only accepted `type`.
    id: str = Field(default_factory=make_tool_call_id)
    type: Literal["function"] = "function"
    function: FunctionCall


1489
# Incremental counterpart of `FunctionCall`: both fields are optional.
class DeltaFunctionCall(BaseModel):
    name: str | None = None
    arguments: str | None = None
1492
1493
1494
1495


# a tool call delta where everything is optional
# (except `index`, which is the only required field)
class DeltaToolCall(OpenAIBaseModel):
    id: str | None = None
    type: Literal["function"] | None = None
    index: int
    function: DeltaFunctionCall | None = None
1500
1501
1502
1503
1504
1505
1506


# Container for tool calls extracted from model output, plus any plain
# content that accompanied them.
class ExtractedToolCallInformation(BaseModel):
    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: list[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: str | None = None
1512
1513


1514
# A full (non-streaming) chat message as returned in a chat completion choice.
class ChatMessage(OpenAIBaseModel):
    role: str
    content: str | None = None
    refusal: str | None = None
    annotations: OpenAIAnnotation | None = None
    audio: OpenAIChatCompletionAudio | None = None
    function_call: FunctionCall | None = None
    tool_calls: list[ToolCall] = Field(default_factory=list)

    # vLLM-specific fields that are not in OpenAI spec
    reasoning: str | None = None
    reasoning_content: str | None = None
    """Deprecated: use `reasoning` instead."""

    @model_validator(mode="after")
    def handle_deprecated_reasoning_content(self):
        """Copy reasoning to reasoning_content for backward compatibility."""
        # Unconditional overwrite: whatever was passed as `reasoning_content`
        # is replaced by the value of `reasoning` (including None).
        self.reasoning_content = self.reasoning
        return self
1533

1534

1535
1536
1537
# Log-probability record for a single token; `logprob` defaults to -9999.0.
class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: list[int] | None = None
1539
1540
1541


# A token log-probability record extended with a list of alternative
# candidates (`top_logprobs`).
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    # Workaround: redefine fields name cache so that it's not
    # shared with the super class.
    field_names: ClassVar[set[str] | None] = None
    top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
1546
1547
1548


# Wrapper holding the per-token log-probability entries of a choice.
class ChatCompletionLogProbs(OpenAIBaseModel):
    content: list[ChatCompletionLogProbsContent] | None = None
1550
1551


1552
# One entry of `ChatCompletionResponse.choices`.
class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: ChatCompletionLogProbs | None = None
    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: int | str | None = None
    # not part of the OpenAI spec but is useful for tracing the tokens
    # in agent scenarios
    token_ids: list[int] | None = None
1563
1564


1565
# Non-streaming chat completion payload; `id` and `created` are generated
# per instance when not supplied.
class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
    system_fingerprint: str | None = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    prompt_token_ids: list[int] | None = None
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None, description="KVTransfer parameters."
    )
1581
1582


1583
# Incremental message fragment used by the streaming choice models; every
# field is optional.
class DeltaMessage(OpenAIBaseModel):
    role: str | None = None
    content: str | None = None
    reasoning: str | None = None
    reasoning_content: str | None = None
    """Deprecated: use `reasoning` instead."""
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)

    @model_validator(mode="after")
    def handle_deprecated_reasoning_content(self):
        """Copy reasoning to reasoning_content for backward compatibility."""
        # Unconditional overwrite: whatever was passed as `reasoning_content`
        # is replaced by the value of `reasoning` (including None).
        self.reasoning_content = self.reasoning
        return self

1597

1598
# One entry of `ChatCompletionStreamResponse.choices`.
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: ChatCompletionLogProbs | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None
    # not part of the OpenAI spec but for tracing the tokens
    token_ids: list[int] | None = None
1606
1607


1608
# Streaming chat completion chunk; `id` and `created` are generated per
# instance when not supplied.
class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
    # not part of the OpenAI spec but for tracing the tokens
    prompt_token_ids: list[int] | None = None
1617
1618


1619
1620
# One entry of `TranscriptionStreamResponse.choices`.
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None
1623
1624
1625
1626
1627
1628
1629
1630


# Streaming transcription chunk; `id` and `created` are generated per
# instance when not supplied.
class TranscriptionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    object: Literal["transcription.chunk"] = "transcription.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
1632
1633


1634
1635
# Input-side token breakdown used by `ResponseUsage`; the per-turn lists
# default to empty.
class InputTokensDetails(OpenAIBaseModel):
    cached_tokens: int
    input_tokens_per_turn: list[int] = Field(default_factory=list)
    cached_tokens_per_turn: list[int] = Field(default_factory=list)
1638
1639
1640


# Output-side token breakdown used by `ResponseUsage`; counters default to 0
# and the per-turn lists default to empty.
class OutputTokensDetails(OpenAIBaseModel):
    reasoning_tokens: int = 0
    tool_output_tokens: int = 0
    output_tokens_per_turn: list[int] = Field(default_factory=list)
    tool_output_tokens_per_turn: list[int] = Field(default_factory=list)
1645
1646
1647
1648
1649
1650
1651
1652


# Token accounting attached to `ResponsesResponse.usage`; all five fields
# are required.
class ResponseUsage(OpenAIBaseModel):
    input_tokens: int
    input_tokens_details: InputTokensDetails
    output_tokens: int
    output_tokens_details: OutputTokensDetails
    total_tokens: int
1653
1654


1655
1656
1657
1658
1659
1660
def serialize_message(msg):
    """
    Serializes a single message.

    Dispatch order:
      1. ``dict`` instances are returned unchanged (same object).
      2. Objects exposing ``to_dict`` return ``msg.to_dict()``.
      3. Anything else falls back to ``msg.model_dump_json()``.
    """
    if isinstance(msg, dict):
        return msg
    if hasattr(msg, "to_dict"):
        return msg.to_dict()
    # Fallback to the pydantic dump.
    # NOTE(review): model_dump_json() yields a JSON *string*, unlike the two
    # branches above which yield dicts -- confirm consumers expect this.
    return msg.model_dump_json()


def serialize_messages(msgs):
    """
    Serializes multiple messages.

    Returns None when `msgs` is falsy (None or empty); otherwise a list with
    each element passed through `serialize_message`.
    """
    if not msgs:
        return None
    return [serialize_message(item) for item in msgs]


1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
class ResponseRawMessageAndToken(OpenAIBaseModel):
    """Class to show the raw message.
    If message / tokens diverge, tokens is the source of truth"""

    # Raw message text.
    message: str
    # Token ids; authoritative when they disagree with `message` (see above).
    tokens: list[int]
    # Fixed discriminator value for this payload shape.
    type: Literal["raw_message_tokens"] = "raw_message_tokens"


# Type of `ResponsesResponse.input_messages` / `output_messages`: either a
# list of chat message params or a list of raw message+token records.
ResponseInputOutputMessage: TypeAlias = (
    list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken]
)


1689
1690
1691
1692
# Response object for the Responses API. `from_request` is the constructor
# that echoes request / sampling settings back into the response.
class ResponsesResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
    created_at: int = Field(default_factory=lambda: int(time.time()))
    # error: Optional[ResponseError] = None
    incomplete_details: IncompleteDetails | None = None
    instructions: str | None = None
    metadata: Metadata | None = None
    model: str
    object: Literal["response"] = "response"
    output: list[ResponseOutputItem]
    parallel_tool_calls: bool
    temperature: float
    tool_choice: ToolChoice
    tools: list[Tool]
    top_p: float
    background: bool
    max_output_tokens: int
    max_tool_calls: int | None = None
    previous_response_id: str | None = None
    prompt: ResponsePrompt | None = None
    reasoning: Reasoning | None = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"]
    status: ResponseStatus
    text: ResponseTextConfig | None = None
    top_logprobs: int | None = None
    truncation: Literal["auto", "disabled"]
    usage: ResponseUsage | None = None
    user: str | None = None

    # --8<-- [start:responses-response-extra-params]
    # These are populated when enable_response_messages is set to True
    # NOTE: custom serialization is needed
    # see serialize_input_messages and serialize_output_messages
    input_messages: ResponseInputOutputMessage | None = Field(
        default=None,
        description=(
            "If enable_response_messages, we can show raw token input to model."
        ),
    )
    output_messages: ResponseInputOutputMessage | None = Field(
        default=None,
        description=(
            "If enable_response_messages, we can show raw token output of model."
        ),
    )
    # --8<-- [end:responses-response-extra-params]

    # NOTE: openAI harmony doesn't serialize TextContent properly,
    # TODO: this fixes for TextContent, but need to verify for tools etc
    # https://github.com/openai/harmony/issues/78
    @field_serializer("output_messages", when_used="json")
    def serialize_output_messages(self, msgs, _info):
        """JSON-mode serializer for `output_messages`."""
        return serialize_messages(msgs)

    # NOTE: openAI harmony doesn't serialize TextContent properly, this fixes it
    # https://github.com/openai/harmony/issues/78
    @field_serializer("input_messages", when_used="json")
    def serialize_input_messages(self, msgs, _info):
        """JSON-mode serializer for `input_messages`."""
        return serialize_messages(msgs)

    @classmethod
    def from_request(
        cls,
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        model_name: str,
        created_time: int,
        output: list[ResponseOutputItem],
        status: ResponseStatus,
        usage: ResponseUsage | None = None,
        input_messages: ResponseInputOutputMessage | None = None,
        output_messages: ResponseInputOutputMessage | None = None,
    ) -> "ResponsesResponse":
        """Build a response from the originating request and its results.

        `id` is taken from `request.request_id`. `temperature`, `top_p`,
        `max_output_tokens` and `top_logprobs` come from `sampling_params`;
        the remaining echoed settings come from `request`. When `status` is
        "incomplete", `incomplete_details` is set with reason
        "max_output_tokens"; otherwise it stays None.
        """
        incomplete_details: IncompleteDetails | None = None
        if status == "incomplete":
            incomplete_details = IncompleteDetails(reason="max_output_tokens")
        # TODO: implement the other reason for incomplete_details,
        # which is content_filter
        # incomplete_details = IncompleteDetails(reason='content_filter')
        return cls(
            id=request.request_id,
            created_at=created_time,
            incomplete_details=incomplete_details,
            instructions=request.instructions,
            metadata=request.metadata,
            model=model_name,
            output=output,
            input_messages=input_messages,
            output_messages=output_messages,
            parallel_tool_calls=request.parallel_tool_calls,
            temperature=sampling_params.temperature,
            tool_choice=request.tool_choice,
            tools=request.tools,
            top_p=sampling_params.top_p,
            background=request.background,
            max_output_tokens=sampling_params.max_tokens,
            max_tool_calls=request.max_tool_calls,
            previous_response_id=request.previous_response_id,
            prompt=request.prompt,
            reasoning=request.reasoning,
            service_tier=request.service_tier,
            status=status,
            text=request.text,
            top_logprobs=sampling_params.logprobs,
            truncation=request.truncation,
            user=request.user,
            usage=usage,
        )


1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
# Locally declared event model for `response.reasoning_part.done`; it is a
# member of the `StreamingResponsesResponse` union.
class ResponseReasoningPartDoneEvent(OpenAIBaseModel):
    content_index: int
    """The index of the content part that is done."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that is done."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.done"]
    """The type of the event. Always `response.reasoning_part.done`."""


# TODO: this code can be removed once
# https://github.com/openai/openai-python/issues/2634 has been resolved
# Locally declared event model for `response.reasoning_part.added`; it is a
# member of the `StreamingResponsesResponse` union.
class ResponseReasoningPartAddedEvent(OpenAIBaseModel):
    content_index: int
    """The index of the content part that was added."""

    item_id: str
    """The ID of the output item that the content part was added to."""

    output_index: int
    """The index of the output item that the content part was added to."""

    part: ResponseReasoningTextContent
    """The content part that was added."""

    sequence_number: int
    """The sequence number of this event."""

    type: Literal["response.reasoning_part.added"]
    """The type of the event. Always `response.reasoning_part.added`."""


1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
# vLLM Streaming Events
# Note: we override the response type with the vLLM ResponsesResponse type
# (the `type: ignore[override]` silences the narrowed-field-type complaint).
class ResponseCompletedEvent(OpenAIResponseCompletedEvent):
    response: ResponsesResponse  # type: ignore[override]


# OpenAI "created" event with `response` re-typed to vLLM's ResponsesResponse.
class ResponseCreatedEvent(OpenAIResponseCreatedEvent):
    response: ResponsesResponse  # type: ignore[override]


# OpenAI "in progress" event with `response` re-typed to vLLM's
# ResponsesResponse.
class ResponseInProgressEvent(OpenAIResponseInProgressEvent):
    response: ResponsesResponse  # type: ignore[override]


1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
# Union of every event type that may be emitted on a Responses API stream.
# The first three members are the vLLM overrides declared in this file.
StreamingResponsesResponse: TypeAlias = (
    ResponseCreatedEvent
    | ResponseInProgressEvent
    | ResponseCompletedEvent
    | ResponseOutputItemAddedEvent
    | ResponseOutputItemDoneEvent
    | ResponseContentPartAddedEvent
    | ResponseContentPartDoneEvent
    | ResponseReasoningTextDeltaEvent
    | ResponseReasoningTextDoneEvent
    | ResponseReasoningPartAddedEvent
    | ResponseReasoningPartDoneEvent
    | ResponseCodeInterpreterCallInProgressEvent
    | ResponseCodeInterpreterCallCodeDeltaEvent
    | ResponseWebSearchCallInProgressEvent
    | ResponseWebSearchCallSearchingEvent
    | ResponseWebSearchCallCompletedEvent
    | ResponseCodeInterpreterCallCodeDoneEvent
    | ResponseCodeInterpreterCallInterpretingEvent
    | ResponseCodeInterpreterCallCompletedEvent
    | ResponseMcpCallArgumentsDeltaEvent
    | ResponseMcpCallArgumentsDoneEvent
    | ResponseMcpCallInProgressEvent
    | ResponseMcpCallCompletedEvent
)
1882

1883

1884
# Tokenize request for a plain text prompt (the chat-style variant is
# `TokenizeChatRequest`).
class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str | None = None
    prompt: str

    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
1901
1902
1903


# Tokenize request for a list of chat messages rendered through a chat
# template (the plain-prompt variant is `TokenizeCompletionRequest`).
class TokenizeChatRequest(OpenAIBaseModel):
    model: str | None = None
    messages: list[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    tools: list[ChatCompletionToolsParam] | None = Field(
        default=None,
        description=("A list of tools the model may call."),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject input that sets both mutually exclusive prompt flags.

        Runs on the raw input (mode="before"), so only keys explicitly
        present in `data` are seen; field defaults have not been applied yet.

        Raises:
            ValueError: if both `continue_final_message` and
                `add_generation_prompt` are truthy in `data`.
        """
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

1976

1977
# Either of the two tokenize request shapes (plain prompt or chat messages).
TokenizeRequest: TypeAlias = TokenizeCompletionRequest | TokenizeChatRequest
1978
1979
1980
1981
1982


# Tokenization result: token ids, their count, the model's max length and,
# optionally, the token strings.
class TokenizeResponse(OpenAIBaseModel):
    count: int
    max_model_len: int
    tokens: list[int]
    token_strs: list[str] | None = None
1985
1986
1987


# Detokenize request: a list of token ids and an optional model name.
class DetokenizeRequest(OpenAIBaseModel):
    model: str | None = None
    tokens: list[int]
1990
1991
1992
1993


# Detokenization result carrying a single string field, `prompt`.
class DetokenizeResponse(OpenAIBaseModel):
    prompt: str
1994
1995


1996
1997
class TokenizerInfoResponse(OpenAIBaseModel):
    """
    Response containing tokenizer configuration
    equivalent to tokenizer_config.json
    """

    # extra="allow": keys beyond `tokenizer_class` are accepted on the model.
    model_config = ConfigDict(extra="allow")
    tokenizer_class: str


2006
# Request body for loading a LoRA adapter: a name and a path, both required.
class LoadLoRAAdapterRequest(BaseModel):
    lora_name: str
    lora_path: str


2011
# Request body for unloading a LoRA adapter: required name, optional int id.
class UnloadLoRAAdapterRequest(BaseModel):
    lora_name: str
    lora_int_id: int | None = Field(default=None)
2014
2015
2016


## Protocols for Audio
2017
# Allowed values for `response_format` on the audio request models.
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
2018
2019
2020
2021


# Request model for audio transcription, plus the conversion of its sampling
# fields into `SamplingParams`.
class TranscriptionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    language: str | None = None
    """The language of the input audio.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    ## TODO (varun) : Support if set to 0, certain thresholds are met !!

    timestamp_granularities: list[Literal["word", "segment"]] = Field(
        alias="timestamp_granularities[]", default=[]
    )
    """The timestamp granularities to populate for this transcription.

    `response_format` must be set `verbose_json` to use timestamp granularities.
    Either or both of these options are supported: `word`, or `segment`. Note:
    There is no additional latency for segment timestamps, but generating word
    timestamps incurs additional latency.
    """

    stream: bool | None = False
    """When set, it will enable output to be streamed in a similar fashion
    as the Chat Completion endpoint.
    """
    # --8<-- [start:transcription-extra-params]
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )
    # --8<-- [end:transcription-extra-params]

    to_language: str | None = None
    """The language of the output audio we transcribe to.

    Please note that this is not currently used by supported models at this
    time, but it is a placeholder for future use, matching translation api.
    """

    # --8<-- [start:transcription-sampling-params]
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    top_p: float | None = None
    """Enables nucleus (top-p) sampling, where tokens are selected from the
    smallest possible set whose cumulative probability exceeds `p`.
    """

    top_k: int | None = None
    """Limits sampling to the `k` most probable tokens at each step."""

    min_p: float | None = None
    """Filters out tokens with a probability lower than `min_p`, ensuring a
    minimum likelihood threshold during sampling.
    """

    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    frequency_penalty: float | None = 0.0
    """The frequency penalty to use for sampling."""

    repetition_penalty: float | None = None
    """The repetition penalty to use for sampling."""

    presence_penalty: float | None = 0.0
    """The presence penalty to use for sampling."""

    max_completion_tokens: int | None = None
    """The maximum number of tokens to generate."""
    # --8<-- [end:transcription-sampling-params]

    # Default sampling parameters for transcription requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Build the `SamplingParams` for this request.

        For `temperature`, `top_p`, `top_k`, `min_p` and `repetition_penalty`
        the value is resolved in order: the request field if not None, else
        `default_sampling_params`, else `_DEFAULT_SAMPLING_PARAMS`.

        Args:
            default_max_tokens: Used as-is for `max_tokens`; this method does
                not read `self.max_completion_tokens`.
            default_sampling_params: Optional fallback values; None is treated
                as an empty dict.

        Returns:
            SamplingParams with DELTA output when `self.stream` is truthy and
            FINAL_ONLY otherwise.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            presence_penalty=self.presence_penalty,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            extra_args=self.vllm_xargs,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )

    @model_validator(mode="before")
    @classmethod
    def validate_transcription_request(cls, data):
        """Validate the raw request data before field parsing.

        Raises:
            HTTPException: (422) if `file` was supplied as a string.
            VLLMValidationError: if a stream option is set while `stream` is
                falsy; `parameter` names the first offending option.
        """
        if isinstance(data.get("file"), str):
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail="Expected 'file' to be a file-like object, not 'str'.",
            )

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            # Find which specific stream option was set
            invalid_param = next(
                (so for so in stream_opts if data.get(so, False)),
                "stream_include_usage",
            )
            raise VLLMValidationError(
                "Stream options can only be defined when `stream=True`.",
                parameter=invalid_param,
            )

        return data
2214
2215
2216


# Transcription response objects
2217
2218
2219
2220
2221
# Usage record for a transcription: `type` is fixed to "duration" and
# `seconds` is an integer.
class TranscriptionUsageAudio(OpenAIBaseModel):
    type: Literal["duration"] = "duration"
    seconds: int


2222
2223
2224
# Basic transcription result: the text plus a duration-based usage record.
class TranscriptionResponse(OpenAIBaseModel):
    text: str
    """The transcribed text."""
    usage: TranscriptionUsageAudio
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242


# A single word with start/end timestamps; element type of
# `TranscriptionResponseVerbose.words`.
class TranscriptionWord(OpenAIBaseModel):
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


# A transcribed segment with timing and decoding details; element type of
# `TranscriptionResponseVerbose.segments`.
class TranscriptionSegment(OpenAIBaseModel):
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float | None = None
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float | None = None
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float | None = None
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


# Verbose transcription result: text plus duration, language and optional
# segment / word details.
class TranscriptionResponseVerbose(OpenAIBaseModel):
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: list[TranscriptionSegment] | None = None
    """Segments of the transcribed text and their corresponding details."""

    words: list[TranscriptionWord] | None = None
    """Extracted words and their corresponding timestamps."""
2296
2297


2298
2299
2300
2301
2302
# Either of the two transcription result shapes (basic or verbose).
TranscriptionResponseVariant: TypeAlias = (
    TranscriptionResponse | TranscriptionResponseVerbose
)


2303
2304
# One entry of `TranslationStreamResponse.choices`.
class TranslationResponseStreamChoice(OpenAIBaseModel):
    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None
2307
2308
2309
2310
2311
2312
2313
2314


# Streaming translation chunk; `id` and `created` are generated per instance
# when not supplied.
class TranslationStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
    object: Literal["translation.chunk"] = "translation.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranslationResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327


class TranslationRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranslation

    file: UploadFile
    """
    The audio file object (not file name) to translate, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    # TODO support additional sampling parameters
    # --8<-- [start:translation-sampling-params]
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """
    # --8<-- [end:translation-sampling-params]

    # --8<-- [start:translation-extra-params]
    language: str | None = None
    """The language of the input audio we translate from.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy.
    """

    to_language: str | None = None
    """The language we translate the input audio to.

    Please note that this is not supported by all models, refer to the specific
    model documentation for more details.
    For instance, Whisper only supports `to_language=en`.
    """

    stream: bool | None = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
    Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False

    max_completion_tokens: int | None = None
    """The maximum number of tokens to generate."""
    # --8<-- [end:translation-extra-params]

    # Default sampling parameters for translation requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Build the `SamplingParams` for this translation request.

        Args:
            default_max_tokens: Passed through unchanged as `max_tokens`.
            default_sampling_params: Optional fallback values. Only the
                "temperature" key is read, and only when `self.temperature`
                is None.

        Returns:
            The object produced by `SamplingParams.from_optional`, using
            DELTA output when `self.stream` is truthy and FINAL_ONLY
            otherwise.
        """
        # NOTE(review): `max_completion_tokens` is not consulted here;
        # presumably the caller folds it into `default_max_tokens` - confirm.
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        # Default parameters
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject stream-usage flags when `stream` itself is not enabled.

        Args:
            data: The raw, not yet validated input for the model.

        Returns:
            `data`, unchanged.

        Raises:
            VLLMValidationError: If `stream_include_usage` or
                `stream_continuous_usage_stats` is truthy while `stream` is
                falsy or missing. `parameter` names the first offending flag.
        """
        # A "before" validator sees the raw input, which need not be a
        # mapping. Anything without `.get` is handed on untouched so that
        # regular field validation reports on it instead of an AttributeError.
        if not callable(getattr(data, "get", None)):
            return data

        if data.get("stream", False):
            return data

        stream_opts = ("stream_include_usage", "stream_continuous_usage_stats")
        # First stream option that was set, or None when none of them was.
        invalid_param = next(
            (so for so in stream_opts if data.get(so, False)),
            None,
        )
        if invalid_param is not None:
            raise VLLMValidationError(
                "Stream options can only be defined when `stream=True`.",
                parameter=invalid_param,
            )

        return data


# Translation response objects
class TranslationResponse(OpenAIBaseModel):
    # Minimal translation payload: only the resulting text.
    text: str
    """Text produced by the translation."""


class TranslationWord(OpenAIBaseModel):
    # A single word together with the time span it covers.

    end: float
    """Timestamp, in seconds, at which the word ends."""

    start: float
    """Timestamp, in seconds, at which the word begins."""

    word: str
    """Text of the word itself."""


class TranslationSegment(OpenAIBaseModel):
    # Per-segment detail of a translation. The three quality metrics
    # (avg_logprob, compression_ratio, no_speech_prob) are optional and
    # default to None.

    id: int
    """Unique identifier of the segment."""

    avg_logprob: float | None = None
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float | None = None
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float | None = None
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranslationResponseVerbose(OpenAIBaseModel):
    # Verbose translation payload: the full text plus optional per-segment
    # and per-word detail.

    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The translated text."""

    segments: list[TranslationSegment] | None = None
    """Segments of the translated text and their corresponding details."""

    words: list[TranslationWord] | None = None
    """Extracted words and their corresponding timestamps."""


# Union of the plain and the verbose translation response models.
TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose


####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
    # Request model of the token-in / token-out interface: the prompt is
    # supplied as token ids and the sampling parameters are passed explicitly.

    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    token_ids: list[int]
    """The token ids to generate text from."""

    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None
    """The processed MM inputs for the model."""

    sampling_params: SamplingParams
    """The sampling parameters for the model."""

    model: str | None = None

    stream: bool | None = False
    stream_options: StreamOptions | None = None
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )


class GenerateResponseChoice(BaseModel):
    # One entry of `GenerateResponse.choices`.

    index: int
    logprobs: ChatCompletionLogProbs | None = None
    # "stop" is used as the default, following the OpenAI spec.
    finish_reason: str | None = "stop"
    token_ids: list[int] | None = None


class GenerateResponse(BaseModel):
    # Response model of the token-in / token-out interface.

    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    choices: list[GenerateResponseChoice]

    # Optional list with one entry per position; each entry is either None
    # or a mapping from an int key to a `Logprob`.
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )