protocol.py 90.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
6
import json
Zhuohan Li's avatar
Zhuohan Li committed
7
import time
8
from http import HTTPStatus
9
from typing import Annotated, Any, ClassVar, Literal, Optional, Union
Zhuohan Li's avatar
Zhuohan Li committed
10

11
import regex as re
12
import torch
13
from fastapi import HTTPException, UploadFile
14
15
16
17
18
19
# yapf: disable
from openai.types.chat.chat_completion_audio import (
    ChatCompletionAudio as OpenAIChatCompletionAudio)
from openai.types.chat.chat_completion_message import (
    Annotation as OpenAIAnnotation)
# yapf: enable
20
21
from openai.types.responses import (ResponseFunctionToolCall,
                                    ResponseInputItemParam, ResponseOutputItem,
22
                                    ResponsePrompt, ResponseReasoningItem,
23
24
25
26
27
28
29
30
31
                                    ResponseStatus)

# Backward compatibility for OpenAI client versions
try:  # For older openai versions (< 1.100.0)
    from openai.types.responses import ResponseTextConfig
except ImportError:  # For newer openai versions (>= 1.100.0)
    from openai.types.responses import (ResponseFormatTextConfig as
                                        ResponseTextConfig)

32
33
34
from openai.types.responses.response import ToolChoice
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
35
36
from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
                      ValidationInfo, field_validator, model_validator)
37
from typing_extensions import TypeAlias
Zhuohan Li's avatar
Zhuohan Li committed
38

39
from vllm import envs
40
41
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         random_tool_call_id)
42
43
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam)
44
from vllm.logger import init_logger
45
from vllm.pooling_params import PoolingParams
46
47
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
48
from vllm.sequence import Logprob
49
from vllm.utils import random_uuid, resolve_obj_by_qualname
50

51
52
logger = init_logger(__name__)

53
_LONG_INFO = torch.iinfo(torch.long)
54

Zhuohan Li's avatar
Zhuohan Li committed
55

56
class OpenAIBaseModel(BaseModel):
    # Base class for the request/response schemas in this module.
    # OpenAI API does allow extra fields, so unknown keys are accepted; a
    # warning lists the keys that match no declared field name or alias.
    model_config = ConfigDict(extra="allow")

    # Per-class cache of accepted field names and aliases, filled lazily by
    # ``__log_extra_fields__`` the first time the class validates a dict.
    field_names: ClassVar[Optional[set[str]]] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Validate ``data`` and warn about keys that are not model fields.

        Args:
            data: Raw input for the model. Only ``dict`` inputs are checked
                for unknown keys; anything else is validated and returned
                without a warning.
            handler: Pydantic's inner validator for this model.

        Returns:
            The value produced by ``handler(data)``.
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Look the cache up in this class's own namespace. Reading
        # ``cls.field_names`` would fall back through the MRO, so a subclass
        # validated after its already-cached parent would reuse the parent's
        # smaller set and wrongly report its own fields as ignored.
        field_names = cls.__dict__.get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases, since
            # incoming JSON may use either spelling.
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        if any(k not in field_names for k in data):
            logger.warning(
                "The following fields were present in the request "
                "but ignored: %s",
                data.keys() - field_names,
            )
        return result


89
class ErrorInfo(OpenAIBaseModel):
    # Details of an error reply: a human-readable message, an error type,
    # the optional name of the offending parameter, and an integer code.
    message: str
    type: str
    param: Optional[str] = None
    code: int


96
97
98
99
class ErrorResponse(OpenAIBaseModel):
    # Top-level error envelope; the details are nested under ``error``.
    error: ErrorInfo


100
class ModelPermission(OpenAIBaseModel):
    # Permission record listed in ``ModelCard.permission``. ``id`` and
    # ``created`` are generated per instance; ``created`` is integer Unix
    # seconds taken from ``time.time()``.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


115
class ModelCard(OpenAIBaseModel):
    # One entry of ``ModelList.data``, describing a served model.
    id: str
    object: str = "model"
    # Integer Unix seconds, taken when the card is constructed.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: list[ModelPermission] = Field(default_factory=list)


126
class ModelList(OpenAIBaseModel):
    # List envelope wrapping ``ModelCard`` entries; each instance gets its
    # own empty list by default.
    object: str = "list"
    data: list[ModelCard] = Field(default_factory=list)


131
132
133
134
class PromptTokenUsageInfo(OpenAIBaseModel):
    # Extra detail about prompt tokens, nested in
    # ``UsageInfo.prompt_tokens_details``; defaults to None when unset.
    cached_tokens: Optional[int] = None


135
class UsageInfo(OpenAIBaseModel):
    # Token accounting attached to a response. Counts default to 0 and the
    # prompt-token detail block defaults to None.
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
    prompt_tokens_details: Optional[PromptTokenUsageInfo] = None


142
143
144
145
146
class RequestResponseMetadata(BaseModel):
    # Per-request bookkeeping. Derives from plain ``BaseModel`` rather than
    # ``OpenAIBaseModel``, so unknown keys get pydantic's default handling
    # and no extra-field warning is logged.
    request_id: str
    final_usage_info: Optional[UsageInfo] = None


147
148
149
150
151
class JsonSchemaResponseFormat(OpenAIBaseModel):
    # Payload of ``ResponseFormat.json_schema``; the schema dict itself is
    # what ends up as the guided-decoding JSON constraint.
    name: str
    description: Optional[str] = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: Optional[dict[str, Any]] = Field(default=None, alias='schema')
    strict: Optional[bool] = None


156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
class StructuralTag(OpenAIBaseModel):
    # One entry of ``StructuralTagResponseFormat.structures``: a ``begin`` /
    # ``end`` marker pair plus an optional JSON schema (sent as ``schema``).
    begin: str
    # schema is the field, but that causes conflicts with pydantic so
    # instead use structural_tag_schema with an alias
    structural_tag_schema: Optional[dict[str, Any]] = Field(default=None,
                                                            alias="schema")
    end: str


class StructuralTagResponseFormat(OpenAIBaseModel):
    # vLLM-specific ``response_format`` variant. Chat requests dump it with
    # its aliases to a JSON string and forward that as ``structural_tag``.
    type: Literal["structural_tag"]
    structures: list[StructuralTag]
    triggers: list[str]


171
class ResponseFormat(OpenAIBaseModel):
    # type must be "json_schema", "json_object", or "text"; the nested
    # ``json_schema`` payload is only consulted when type is "json_schema".
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[JsonSchemaResponseFormat] = None


177
178
179
# A chat ``response_format`` may be the standard OpenAI shape or vLLM's
# structural-tag extension.
AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat]


180
class StreamOptions(OpenAIBaseModel):
    # Options that only apply to streamed requests.
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = False


185
186
187
class FunctionDefinition(OpenAIBaseModel):
    # A tool function offered to the model. When a chat request forces this
    # function by name, ``parameters`` is used as the guided-decoding JSON
    # schema.
    name: str
    description: Optional[str] = None
    parameters: Optional[dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    # One entry of a chat request's ``tools`` list; only function tools are
    # modeled.
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    # Names the tool function that a named ``tool_choice`` refers to.
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    # ``tool_choice`` value that forces one specific function; the name must
    # match an entry in the request's ``tools``.
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


205
206
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
207
208
class LogitsProcessorConstructor(BaseModel):
    # Spec for a logits processor built per request: ``qualname`` names the
    # class or factory, which ``get_logits_processors`` resolves and calls
    # with ``*args`` / ``**kwargs`` (treated as empty when None).
    qualname: str
    args: Optional[list[Any]] = None
    kwargs: Optional[dict[str, Any]] = None

    # Reject unknown keys instead of storing them as extra attributes.
    model_config = ConfigDict(extra="forbid")

214

215
# Each entry is a qualified-name string or a constructor spec; both forms are
# resolved by ``get_logits_processors``.
LogitsProcessors = list[Union[str, LogitsProcessorConstructor]]
216
217
218


def get_logits_processors(processors: Optional[LogitsProcessors],
                          pattern: Optional[str]) -> Optional[list[Any]]:
    """Resolve user-supplied logits processors into Python objects.

    Args:
        processors: Qualified-name strings, or ``LogitsProcessorConstructor``
            specs whose resolved object is called with their ``args`` /
            ``kwargs``.
        pattern: Server-configured regex (``--logits-processor-pattern``).
            Every qualified name must satisfy ``re.match(pattern, name)``,
            which anchors at the start only. A falsy pattern means the
            feature is disabled.

    Returns:
        The resolved processors in input order, or ``None`` when
        ``processors`` is ``None`` or empty.

    Raises:
        ValueError: if processors are given while no pattern is configured,
            if a name does not match ``pattern``, or if a name cannot be
            resolved.
    """
    if not processors:
        return None
    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information.")

    logits_processors = []
    for processor in processors:
        qualname = (processor
                    if isinstance(processor, str) else processor.qualname)
        if not re.match(pattern, qualname):
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information.")
        try:
            logits_processor = resolve_obj_by_qualname(qualname)
        except Exception as e:
            # Any import/lookup failure is surfaced to the client as a
            # ValueError, keeping the original cause chained.
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {e}"
            ) from e
        if isinstance(processor, LogitsProcessorConstructor):
            # Constructor specs name a factory; instantiate it here.
            logits_processor = logits_processor(*processor.args or [],
                                                **processor.kwargs or {})
        logits_processors.append(logits_processor)
    return logits_processors


249
# One element of a Responses API ``input`` list: a raw input item, a
# reasoning item, or a function tool call.
ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam,
                                           ResponseReasoningItem,
                                           ResponseFunctionToolCall]


254
255
256
257
258
259
260
261
262
263
264
265
266
267
class ResponsesRequest(OpenAIBaseModel):
    # Request body of the OpenAI-compatible Responses API.
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/responses/create
    background: Optional[bool] = False
    include: Optional[list[
        Literal[
            "code_interpreter_call.outputs",
            "computer_call_output.output.image_url",
            "file_search_call.results",
            "message.input_image.image_url",
            "message.output_text.logprobs",
            "reasoning.encrypted_content",
        ],
    ]] = None
    input: Union[str, list[ResponseInputOutputItem]]
    instructions: Optional[str] = None
    max_output_tokens: Optional[int] = None
    max_tool_calls: Optional[int] = None
    metadata: Optional[Metadata] = None
    model: Optional[str] = None
    parallel_tool_calls: Optional[bool] = True
    previous_response_id: Optional[str] = None
    prompt: Optional[ResponsePrompt] = None
    reasoning: Optional[Reasoning] = None
    service_tier: Literal["auto", "default", "flex", "scale",
                          "priority"] = "auto"
    store: Optional[bool] = True
    stream: Optional[bool] = False
    temperature: Optional[float] = None
    text: Optional[ResponseTextConfig] = None
    tool_choice: ToolChoice = "auto"
    tools: list[Tool] = Field(default_factory=list)
    top_logprobs: Optional[int] = 0
    top_p: Optional[float] = None
    truncation: Optional[Literal["auto", "disabled"]] = "disabled"
    user: Optional[str] = None

    # --8<-- [start:responses-extra-params]
    request_id: str = Field(
        default_factory=lambda: f"resp_{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."),
    )
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )
    cache_salt: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."))
    # --8<-- [end:responses-extra-params]

    # Last-resort values when neither the request nor the server-provided
    # defaults set a sampling parameter.
    _DEFAULT_SAMPLING_PARAMS = {
        "temperature": 1.0,
        "top_p": 1.0,
    }

    def to_sampling_params(
        self,
        default_max_tokens: int,
        default_sampling_params: Optional[dict] = None,
    ) -> SamplingParams:
        """Translate this request into engine ``SamplingParams``.

        Args:
            default_max_tokens: Cap on generated tokens. Used as-is when
                ``max_output_tokens`` is unset; otherwise the smaller of the
                two is used.
            default_sampling_params: Optional server-side defaults. They are
                consulted for ``temperature`` / ``top_p`` when the request
                leaves them unset, and are the only source of
                ``stop_token_ids``.

        Returns:
            Sampling parameters that stream deltas when ``stream`` is set
            and return only the final output otherwise.

        Raises:
            NotImplementedError: if ``text.format`` requests ``json_object``.
        """
        if self.max_output_tokens is None:
            max_tokens = default_max_tokens
        else:
            max_tokens = min(self.max_output_tokens, default_max_tokens)

        default_sampling_params = default_sampling_params or {}
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
        stop_token_ids = default_sampling_params.get("stop_token_ids")

        # Structured output
        guided_decoding = None
        if self.text is not None and self.text.format is not None:
            response_format = self.text.format
            if response_format.type == "json_schema":
                guided_decoding = GuidedDecodingParams.from_optional(
                    json=response_format.schema_)
            elif response_format.type == "json_object":
                raise NotImplementedError("json_object is not supported")

        # TODO: add more parameters
        return SamplingParams.from_optional(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            logprobs=self.top_logprobs,
            stop_token_ids=stop_token_ids,
            output_kind=(RequestOutputKind.DELTA
                         if self.stream else RequestOutputKind.FINAL_ONLY),
            guided_decoding=guided_decoding,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_background(cls, data):
        """Reject ``background`` requests that also set ``store`` false."""
        # "before" validators can receive non-dict input (e.g. an existing
        # model instance); leave that to pydantic instead of calling .get.
        if not isinstance(data, dict):
            return data
        if not data.get("background"):
            return data
        if not data.get("store", True):
            raise ValueError(
                "background can only be used when `store` is true")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt(cls, data):
        """Reject requests that supply a ``prompt`` template."""
        if not isinstance(data, dict):
            return data
        if data.get("prompt") is not None:
            raise ValueError("prompt template is not supported")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Require engine V1 and a non-empty string when ``cache_salt`` is set.
        """
        if not isinstance(data, dict):
            return data
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0.")
            if not isinstance(data["cache_salt"],
                              str) or not data["cache_salt"]:
                raise ValueError("Parameter 'cache_salt' must be a "
                                 "non-empty string if provided.")
        return data

395

396
class ChatCompletionRequest(OpenAIBaseModel):
397
398
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
399
    messages: list[ChatCompletionMessageParam]
400
    model: Optional[str] = None
401
    frequency_penalty: Optional[float] = 0.0
402
    logit_bias: Optional[dict[str, float]] = None
403
    logprobs: Optional[bool] = False
404
    top_logprobs: Optional[int] = 0
405
406
407
408
409
    max_tokens: Optional[int] = Field(
        default=None,
        deprecated=
        'max_tokens is deprecated in favor of the max_completion_tokens field')
    max_completion_tokens: Optional[int] = None
410
411
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
412
    response_format: Optional[AnyResponseFormat] = None
413
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
414
    stop: Optional[Union[str, list[str]]] = []
Zhuohan Li's avatar
Zhuohan Li committed
415
    stream: Optional[bool] = False
416
    stream_options: Optional[StreamOptions] = None
417
418
    temperature: Optional[float] = None
    top_p: Optional[float] = None
419
    tools: Optional[list[ChatCompletionToolsParam]] = None
420
421
422
423
424
425
    tool_choice: Optional[Union[
        Literal["none"],
        Literal["auto"],
        Literal["required"],
        ChatCompletionNamedToolChoiceParam,
    ]] = "none"
426
427
    reasoning_effort: Optional[Literal["low", "medium", "high"]] = None
    include_reasoning: bool = True
428

429
    # NOTE this will be ignored by vLLM -- the model determines the behavior
430
    parallel_tool_calls: Optional[bool] = False
Zhuohan Li's avatar
Zhuohan Li committed
431
    user: Optional[str] = None
432

433
    # --8<-- [start:chat-completion-sampling-params]
434
    best_of: Optional[int] = None
435
    use_beam_search: bool = False
436
437
438
    top_k: Optional[int] = None
    min_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
439
    length_penalty: float = 1.0
440
    stop_token_ids: Optional[list[int]] = []
441
442
443
444
445
446
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
447
    prompt_logprobs: Optional[int] = None
448
    allowed_token_ids: Optional[list[int]] = None
449
    bad_words: list[str] = Field(default_factory=list)
450
    # --8<-- [end:chat-completion-sampling-params]
451

452
    # --8<-- [start:chat-completion-extra-params]
453
    echo: bool = Field(
454
455
456
457
458
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
459
    add_generation_prompt: bool = Field(
460
461
462
463
464
465
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
466
467
468
469
470
471
472
473
474
    continue_final_message: bool = Field(
        default=False,
        description=
        ("If this is set, the chat will be formatted so that the final "
         "message in the chat is open-ended, without any EOS tokens. The "
         "model will continue this message rather than starting a new one. "
         "This allows you to \"prefill\" part of the model's response for it. "
         "Cannot be used at the same time as `add_generation_prompt`."),
    )
475
    add_special_tokens: bool = Field(
476
477
478
479
480
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
481
            "special tokens so this should be set to false (as is the "
482
483
            "default)."),
    )
484
    documents: Optional[list[dict[str, str]]] = Field(
485
486
487
488
489
490
491
492
493
494
495
496
        default=None,
        description=
        ("A list of dicts representing documents that will be accessible to "
         "the model if it is performing RAG (retrieval-augmented generation)."
         " If the template does not support RAG, this argument will have no "
         "effect. We recommend that each document should be a dict containing "
         "\"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
497
498
499
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
500
    )
501
    chat_template_kwargs: Optional[dict[str, Any]] = Field(
502
        default=None,
503
504
505
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."),
506
    )
507
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
508
509
510
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
511
512
513
514
515
516
517
518
519
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
520
    guided_choice: Optional[list[str]] = Field(
521
522
523
524
525
526
527
528
529
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
530
531
532
533
534
    structural_tag: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the structural tag schema."),
    )
535
536
537
538
539
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
540
541
            "'outlines' / 'lm-format-enforcer'"),
    )
542
543
544
545
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
546
547
            "for guided json decoding."),
    )
548
549
550
551
552
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
553
554
            "if the served model does not use priority scheduling."),
    )
555
556
557
558
559
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
560
561
            "through out the inference process and return in response."),
    )
562
563
564
565
566
567
568
569
570
571
572
    logits_processors: Optional[LogitsProcessors] = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."))
573
574
575
576
577
578
    return_tokens_as_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."))
579
580
581
582
583
584
585
586
    return_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."))
587
588
589
590
591
592
593
594
595
    cache_salt: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."))
Robert Shaw's avatar
Robert Shaw committed
596
597
598
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.")
599

600
601
602
603
604
605
    vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
        default=None,
        description=("Additional request parameters with string or "
                     "numeric values, used by custom extensions."),
    )

606
    # --8<-- [end:chat-completion-extra-params]
Zhuohan Li's avatar
Zhuohan Li committed
607

608
609
610
611
612
    # Default sampling parameters for chat completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
613
        "top_k": 0,
614
615
616
617
        "min_p": 0.0,
    }

    def to_beam_search_params(
            self, max_tokens: int,
            default_sampling_params: dict) -> BeamSearchParams:
        """Build ``BeamSearchParams`` for a beam-search chat request.

        Args:
            max_tokens: Maximum number of tokens to generate.
            default_sampling_params: Server-side defaults. Only
                ``temperature`` is read from it, and only when the request
                does not set one.

        Returns:
            Beam search parameters whose beam width is the request's ``n``,
            or 1 when ``n`` is unset.
        """
        n = self.n if self.n is not None else 1

        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

635
    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: Optional[str],
        default_sampling_params: dict,
    ) -> SamplingParams:
        """Translate this chat request into engine ``SamplingParams``.

        Args:
            max_tokens: Maximum number of tokens to generate.
            logits_processor_pattern: Server regex that user-supplied
                ``logits_processors`` must match; when unset, supplying
                logits processors is an error.
            default_sampling_params: Server-side defaults used for
                ``repetition_penalty``, ``temperature``, ``top_p``,
                ``top_k`` and ``min_p`` when the request leaves them unset.

        Returns:
            The sampling parameters for this request.

        Raises:
            ValueError: if ``response_format`` has type ``json_schema``
                without a ``json_schema`` payload, if a named tool choice is
                missing from ``tools``, or if logits processors are rejected.

        Note:
            A ``json_schema`` / ``structural_tag`` response format is copied
            onto ``self.guided_json`` / ``self.structural_tag`` as a side
            effect.
        """
        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])

        # With echo, prompt logprobs default to the completion's depth unless
        # prompt_logprobs was requested explicitly.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        guided_json_object = None
        if self.response_format is not None:
            if self.response_format.type == "json_object":
                guided_json_object = True
            elif self.response_format.type == "json_schema":
                json_schema = self.response_format.json_schema
                # This is client input: raise a ValueError instead of an
                # assert, which would be stripped under ``python -O``.
                if json_schema is None:
                    raise ValueError(
                        "response_format.json_schema is required when "
                        "response_format.type is 'json_schema'")
                self.guided_json = json_schema.json_schema
            elif self.response_format.type == "structural_tag":
                structural_tag = self.response_format
                # Type narrowing only: the "structural_tag" literal is unique
                # to StructuralTagResponseFormat.
                assert structural_tag is not None and isinstance(
                    structural_tag, StructuralTagResponseFormat)
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structural_tag = json.dumps(s_tag_obj)

        guided_decoding = GuidedDecodingParams.from_optional(
            json=self._get_guided_json_from_tool() or self.guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern,
            structural_tag=self.structural_tag,
        )

        # Copy vllm_xargs so that adding kv_transfer_params below does not
        # mutate the request's own field.
        extra_args: dict[str, Any] = (dict(self.vllm_xargs)
                                      if self.vllm_xargs else {})
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            logits_processors=get_logits_processors(self.logits_processors,
                                                    logits_processor_pattern),
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=(RequestOutputKind.DELTA
                         if self.stream else RequestOutputKind.FINAL_ONLY),
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias,
            bad_words=self.bad_words,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )

    def _get_guided_json_from_tool(
            self) -> Optional[Union[str, dict, BaseModel]]:
        """Derive a guided-decoding JSON schema from the tool settings.

        Returns:
            * ``None`` when ``tool_choice`` is ``"none"``, when no tools are
              given, or for any other choice not handled below (e.g.
              ``"auto"``).
            * The named function's ``parameters`` schema when
              ``tool_choice`` names a specific function.
            * For ``"required"``: an array schema (at least one item) whose
              items match any declared tool as ``{"name", "parameters"}``.
              Any ``$defs`` found in tool parameters are hoisted to the
              top level of the returned schema.

        Raises:
            ValueError: if the named tool is not present in ``tools``, or if
                two tools declare different schemas under the same ``$defs``
                name.

        The request's own tool definitions are never modified, so calling
        this method repeatedly yields the same schema.
        """
        # user has chosen to not use any tool
        if self.tool_choice == "none" or self.tools is None:
            return None

        # user has chosen to use a named tool
        if isinstance(self.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_name = self.tool_choice.function.name
            tools = {tool.function.name: tool.function for tool in self.tools}
            if tool_name not in tools:
                raise ValueError(
                    f"Tool '{tool_name}' has not been passed in `tools`.")
            tool = tools[tool_name]
            return tool.parameters

        if self.tool_choice == "required":
            # Pydantic schema generation cannot be used since the JSON schema
            # has to be constructed for a specific instantiation of a tool list
            # so that parameters of a function are correctly generated
            # based on the chosen function name
            def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
                parameters = tool.function.parameters
                if parameters:
                    # "$defs" are hoisted to the top-level schema below.
                    # Build a copy without them rather than popping, so the
                    # request's tool definitions stay intact.
                    parameters = {
                        key: value
                        for key, value in parameters.items()
                        if key != "$defs"
                    }
                else:
                    # parameters are always generated as '{}' in the final
                    # output if they are missing from the request
                    # (i.e. are None or '{}') so the schema is
                    # updated to produce an empty object in that case
                    parameters = {"type": "object", "properties": {}}
                return {
                    "properties": {
                        "name": {
                            "type": "string",
                            "enum": [tool.function.name]
                        },
                        "parameters": parameters,
                    },
                    "required": ["name", "parameters"]
                }

            def get_tool_schema_defs(
                    tools: list[ChatCompletionToolsParam]) -> dict:
                # Merge every tool's "$defs"; identical re-declarations are
                # fine, conflicting ones cannot be represented in one schema.
                all_defs: dict[str, dict[str, Any]] = {}
                for tool in tools:
                    if tool.function.parameters is None:
                        continue
                    defs = tool.function.parameters.get("$defs", {})
                    for def_name, def_schema in defs.items():
                        if def_name in all_defs and all_defs[
                                def_name] != def_schema:
                            raise ValueError(
                                f"Tool definition '{def_name}' has "
                                "multiple schemas, which is not "
                                "supported.")
                        all_defs[def_name] = def_schema
                return all_defs

            json_schema = {
                "type": "array",
                "minItems": 1,
                "items": {
                    "type": "object",
                    "anyOf": [get_tool_schema(tool) for tool in self.tools]
                }
            }

            json_schema_defs = get_tool_schema_defs(self.tools)
            if json_schema_defs:
                json_schema["$defs"] = json_schema_defs

            return json_schema

        return None
801

802
    @model_validator(mode="before")
803
    @classmethod
804
805
    def validate_stream_options(cls, data):
        """Allow ``stream_options`` only on streaming requests.

        Args:
            data: Raw request payload (before field validation).

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: if ``stream_options`` is truthy while ``stream`` is
                falsy or absent.
        """
        options = data.get("stream_options")
        if not options:
            return data
        if data.get("stream"):
            return data
        raise ValueError(
            "Stream options can only be defined when `stream=True`.")

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate the logprob-related fields of a raw chat payload.

        * ``prompt_logprobs`` > 0 is rejected when streaming, and negative
          values are always rejected.
        * ``top_logprobs`` must be non-negative, and any value above 0
          requires ``logprobs`` to be truthy.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: on any violation listed above.
        """
        prompt_lp = data.get("prompt_logprobs")
        if prompt_lp is not None:
            if data.get("stream") and prompt_lp > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")
            if prompt_lp < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        top_lp = data.get("top_logprobs")
        if top_lp is None:
            return data

        if top_lp < 0:
            raise ValueError("`top_logprobs` must be a positive value.")
        if top_lp > 0 and not data.get("logprobs"):
            raise ValueError(
                "when using `top_logprobs`, `logprobs` must be set to true.")
        return data
832

833
834
835
    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Reject conflicting guided-decoding settings on a raw chat payload.

        At most one of ``guided_json``, ``guided_regex`` and
        ``guided_choice`` may be set. None of them may be combined with a
        named-function ``tool_choice``, because a named tool supplies its own
        JSON schema for guided decoding.

        ``tool_choice`` values of ``"none"``, ``"auto"`` or ``"required"``,
        and a null or absent ``tool_choice``, remain compatible with guided
        decoding.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: on either conflict (or re-raises ``data`` if it is
                itself a ``ValueError``).
        """
        if isinstance(data, ValueError):
            raise data

        guide_count = sum(
            data.get(key) is not None
            for key in ("guided_json", "guided_regex", "guided_choice"))

        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")

        # you can only either use guided decoding or tools, not both.
        # This previously tested `guide_count > 1`, which the check above
        # already rules out, so the conflict was never reported.
        tool_choice = data.get("tool_choice")
        if guide_count > 0 and tool_choice not in (
                None,
                "none",
                "auto",
                "required",
        ):
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
861
862
863
864
    def check_tool_usage(cls, data):
        """Validate and normalize ``tool_choice`` against ``tools``.

        Operates on the raw request payload, so values are plain JSON-like
        objects (dicts, lists, strings).

        * If ``tool_choice`` is absent but ``tools`` is non-empty, it
          defaults to ``"auto"``.
        * ``"none"`` or a null ``tool_choice`` needs no further checks.
        * Any other ``tool_choice`` requires ``tools``. It must be
          ``"auto"``, ``"required"`` or a named-function dict.
        * ``"required"`` with an empty tool list is rewritten to ``"none"``
          and ``tools`` is removed, to align with OpenAI's behavior.
        * A named-function dict must name one of the provided tools.

        Returns:
            ``data``, possibly modified in place as described above.

        Raises:
            ValueError: for any violation above. This includes malformed
                entries in ``tools``, which previously surfaced as
                ``KeyError``/``TypeError``.
        """
        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        tool_choice = data.get("tool_choice")

        # "none" (or an unset/null choice) -- no validation is needed
        if tool_choice is None or tool_choice == "none":
            return data

        # ensure that if "tool choice" is specified, tools are present
        tools = data.get("tools")
        if tools is None:
            raise ValueError(
                "When using `tool_choice`, `tools` must be set.")

        # make sure that tool choice is either a named tool
        # OR that it's set to "auto" or "required"
        if tool_choice not in ["auto", "required"] and not isinstance(
                tool_choice, dict):
            raise ValueError(
                f'Invalid value for `tool_choice`: {tool_choice}! '
                'Only named tools, "none", "auto" or "required" '
                'are supported.')

        # if tool_choice is "required" but the "tools" list is empty,
        # override the data to behave like "none" to align with
        # OpenAI's behavior.
        if tool_choice == "required" and isinstance(tools,
                                                    list) and len(tools) == 0:
            data["tool_choice"] = "none"
            del data["tools"]
            return data

        # ensure that if "tool_choice" is specified as an object,
        # it matches a valid tool
        if isinstance(tool_choice, dict):
            correct_usage_message = 'Correct usage: `{"type": "function",' \
                ' "function": {"name": "my_function"}}`'

            function = tool_choice.get("function")
            if not isinstance(function, dict):
                raise ValueError(
                    f"Invalid value for `function`: `{function}` in "
                    f"`tool_choice`! {correct_usage_message}")
            if "name" not in function:
                raise ValueError(f"Expected field `name` in `function` in "
                                 f"`tool_choice`! {correct_usage_message}")
            function_name = function["name"]
            if not isinstance(function_name,
                              str) or len(function_name) == 0:
                raise ValueError(
                    f"Invalid `name` in `function`: `{function_name}`"
                    f" in `tool_choice`! {correct_usage_message}")

            def declared_name(tool):
                # Malformed entries (non-dict, missing "function"/"name")
                # simply do not match instead of raising KeyError, which
                # pydantic would not turn into a validation error.
                if not isinstance(tool, dict):
                    return None
                spec = tool.get("function")
                return spec.get("name") if isinstance(spec, dict) else None

            if not any(declared_name(tool) == function_name for tool in tools):
                raise ValueError(
                    "The tool specified in `tool_choice` does not match any"
                    " of the specified `tools`")
        return data

930
931
932
933
934
935
936
937
938
    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Forbid enabling both ``continue_final_message`` and
        ``add_generation_prompt`` in the same request.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: if both flags are truthy.
        """
        continue_final = data.get("continue_final_message")
        if not continue_final or not data.get("add_generation_prompt"):
            return data
        raise ValueError("Cannot set both `continue_final_message` and "
                         "`add_generation_prompt` to True.")

939
940
941
942
943
944
945
946
947
948
949
950
951
952
    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Validate ``cache_salt`` when the payload provides one.

        The salt requires the V1 engine (``envs.VLLM_USE_V1``) and must be a
        non-empty string.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: if the engine is V0 or the salt is not a non-empty
                string.
        """
        salt = data.get("cache_salt")
        if salt is None:
            return data
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Parameter 'cache_salt' is not supported with "
                "this instance of vLLM, which uses engine V0.")
        valid_salt = isinstance(salt, str) and bool(salt)
        if not valid_salt:
            raise ValueError("Parameter 'cache_salt' must be a "
                             "non-empty string if provided.")
        return data

Zhuohan Li's avatar
Zhuohan Li committed
953

954
# Request body of the OpenAI-compatible Completions API (see the link below).
# Documented with comments rather than a class docstring because pydantic
# publishes a model docstring as the model's JSON-schema description.
class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: Optional[str] = None
    prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None
    prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, list[str]]] = []
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    user: Optional[str] = None

    # --8<-- [start:completion-sampling-params]
    use_beam_search: bool = False
    top_k: Optional[int] = None
    min_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
    length_penalty: float = 1.0
    stop_token_ids: Optional[list[int]] = []
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[list[int]] = None
    prompt_logprobs: Optional[int] = None
    # --8<-- [end:completion-sampling-params]

    # --8<-- [start:completion-extra-params]
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
    response_format: Optional[AnyResponseFormat] = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
            ", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
        ),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description="If specified, the output will follow the JSON schema.",
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[list[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"),
    )
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."),
    )
    logits_processors: Optional[LogitsProcessors] = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."))

    return_tokens_as_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."))
    return_token_ids: Optional[bool] = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."))

    cache_salt: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."))

    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.")

    vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
        default=None,
        description=("Additional request parameters with string or "
                     "numeric values, used by custom extensions."),
    )

    # --8<-- [end:completion-extra-params]

    # Default sampling parameters for completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self,
        max_tokens: int,
        default_sampling_params: Optional[dict] = None,
    ) -> BeamSearchParams:
        """Build beam-search parameters for this request.

        Args:
            max_tokens: Generation budget chosen by the caller.
            default_sampling_params: Optional server-side defaults. Only
                ``"temperature"`` is consulted, and only when the request
                leaves it unset (final fallback: 1.0).

        Returns:
            ``BeamSearchParams`` whose beam width is ``n`` (1 if unset).
        """
        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get("temperature", 1.0)

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: Optional[str],
        default_sampling_params: Optional[dict] = None,
    ) -> SamplingParams:
        """Translate this request into engine ``SamplingParams``.

        Args:
            max_tokens: Generation budget chosen by the caller. If the
                request sets ``echo`` with ``max_tokens == 0``, a budget of
                1 is used instead.
            logits_processor_pattern: Forwarded to ``get_logits_processors``
                together with the request's ``logits_processors``.
            default_sampling_params: Optional server-side defaults for
                sampling fields the request leaves unset. Unset fields not
                covered here fall back to ``_DEFAULT_SAMPLING_PARAMS``.

        Returns:
            The ``SamplingParams`` built by ``SamplingParams.from_optional``.

        The request object itself is not modified.
        """
        if default_sampling_params is None:
            default_sampling_params = {}

        def _fallback(name: str) -> Any:
            # Server-provided defaults win over the class-level fallbacks.
            return default_sampling_params.get(
                name, self._DEFAULT_SAMPLING_PARAMS[name])

        # Default parameters
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = _fallback("repetition_penalty")
        if (temperature := self.temperature) is None:
            temperature = _fallback("temperature")
        if (top_p := self.top_p) is None:
            top_p = _fallback("top_p")
        if (top_k := self.top_k) is None:
            top_k = _fallback("top_k")
        if (min_p := self.min_p) is None:
            min_p = _fallback("min_p")

        # When echoing, report prompt logprobs at the same depth as the
        # requested completion logprobs unless set explicitly.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        echo_without_generation = self.echo and self.max_tokens == 0

        # Honour every format the `response_format` field documents.
        # Previously only "json_object" was applied, and "json_schema" /
        # "structural_tag" were silently ignored. Locals are used so the
        # request itself is not mutated.
        guided_json = self.guided_json
        guided_json_object = None
        structural_tag = None
        if self.response_format is not None:
            if self.response_format.type == "json_object":
                guided_json_object = True
            elif self.response_format.type == "json_schema":
                json_schema = self.response_format.json_schema
                if (json_schema is not None
                        and json_schema.json_schema is not None):
                    guided_json = json_schema.json_schema
            elif self.response_format.type == "structural_tag":
                structural_tag = json.dumps(
                    self.response_format.model_dump(by_alias=True))

        guided_decoding = GuidedDecodingParams.from_optional(
            json=guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern,
            structural_tag=structural_tag,
        )

        # Copy so that adding kv_transfer_params below does not write into
        # the request's own vllm_xargs dict.
        extra_args: dict[str, Any] = (dict(self.vllm_xargs)
                                      if self.vllm_xargs else {})
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(self.logits_processors,
                                                    logits_processor_pattern),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=(RequestOutputKind.DELTA
                         if self.stream else RequestOutputKind.FINAL_ONLY),
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Allow at most one of guided_json / guided_regex / guided_choice.

        Raises:
            ValueError: if more than one of them is set (non-null).
        """
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate ``prompt_logprobs`` and ``logprobs`` on the raw payload.

        Raises:
            ValueError: if ``prompt_logprobs`` > 0 while streaming, or if
                either value is negative.
        """
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")

            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise ValueError("`logprobs` must be a positive value.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Allow ``stream_options`` only together with ``stream=True``."""
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt_and_prompt_embeds(cls, data):
        """Require at least one of ``prompt`` / ``prompt_embeds``."""
        if data.get("prompt") is None and data.get("prompt_embeds") is None:
            raise ValueError(
                "At least one of `prompt` or `prompt_embeds` must be set.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Require the V1 engine and a non-empty string for ``cache_salt``."""
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0.")
            if not isinstance(data["cache_salt"],
                              str) or not data["cache_salt"]:
                raise ValueError("Parameter 'cache_salt' must be a "
                                 "non-empty string if provided.")
        return data

Zhuohan Li's avatar
Zhuohan Li committed
1282

1283
# Embeddings request whose input is raw text or token ids.
# (Comments instead of a class docstring: pydantic would publish a docstring
# as the model's JSON-schema description.)
class EmbeddingCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: Optional[str] = None
    input: Union[list[int], list[list[int]], str, list[str]]
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:embedding-extra-params]
    add_special_tokens: bool = Field(
        default=True,
        description="If true (the default), special tokens (e.g. BOS) will be added to "
        "the prompt.",
    )
    priority: int = Field(
        default=0,
        description="The priority of the request (lower means earlier handling; "
        "default: 0). Any priority other than 0 will raise an error "
        "if the served model does not use priority scheduling.",
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description="The request_id related to this request. If the caller does "
        "not set it, a random_uuid will be generated. This id is used "
        "through out the inference process and return in response.",
    )
    normalize: Optional[bool] = None

    # --8<-- [end:embedding-extra-params]

    def to_pooling_params(self):
        """Return the ``PoolingParams`` carrying this request's
        ``dimensions`` and ``normalize`` settings."""
        return PoolingParams(normalize=self.normalize,
                             dimensions=self.dimensions)
1321
1322


1323
# Embeddings request whose input is a list of chat messages; the optional
# `chat_template` / `chat_template_kwargs` fields control how they are
# rendered. (Comments instead of a class docstring so pydantic does not
# publish it as the JSON-schema description.)
class EmbeddingChatRequest(OpenAIBaseModel):
    model: Optional[str] = None
    messages: list[ChatCompletionMessageParam]

    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:chat-embedding-extra-params]
    add_special_tokens: bool = Field(
        default=False,
        description="If true, special tokens (e.g. BOS) will be added to the prompt "
        "on top of what is added by the chat template. "
        "For most models, the chat template takes care of adding the "
        "special tokens so this should be set to false (as is the "
        "default).",
    )
    chat_template: Optional[str] = Field(
        default=None,
        description="A Jinja template to use for this conversion. "
        "As of transformers v4.44, default chat template is no longer "
        "allowed, so you must provide a chat template if the tokenizer "
        "does not define one.",
    )
    chat_template_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description="Additional keyword args to pass to the template renderer. "
        "Will be accessible by the chat template.",
    )
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description="Additional kwargs to pass to the HF processor.",
    )
    priority: int = Field(
        default=0,
        description="The priority of the request (lower means earlier handling; "
        "default: 0). Any priority other than 0 will raise an error "
        "if the served model does not use priority scheduling.",
    )
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description="The request_id related to this request. If the caller does "
        "not set it, a random_uuid will be generated. This id is used "
        "through out the inference process and return in response.",
    )
    normalize: Optional[bool] = None
    # --8<-- [end:chat-embedding-extra-params]

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Forbid enabling both ``continue_final_message`` and
        ``add_generation_prompt``.

        NOTE(review): neither key is a declared field of this model, so the
        check only sees them when they arrive in the raw payload.
        """
        continue_final = data.get("continue_final_message")
        if not continue_final or not data.get("add_generation_prompt"):
            return data
        raise ValueError("Cannot set both `continue_final_message` and "
                         "`add_generation_prompt` to True.")

    def to_pooling_params(self):
        """Return the ``PoolingParams`` carrying this request's
        ``dimensions`` and ``normalize`` settings."""
        return PoolingParams(normalize=self.normalize,
                             dimensions=self.dimensions)
1389
1390
1391
1392


# Any accepted embeddings request body: raw text/token input or chat messages.
EmbeddingRequest: TypeAlias = Union[EmbeddingCompletionRequest,
                                    EmbeddingChatRequest]

# Pooling requests currently share the embedding request schemas.
PoolingCompletionRequest: TypeAlias = EmbeddingCompletionRequest
PoolingChatRequest: TypeAlias = EmbeddingChatRequest
PoolingRequest: TypeAlias = Union[PoolingCompletionRequest,
                                  PoolingChatRequest]

1397

1398
class ScoreRequest(OpenAIBaseModel):
    # Request body for scoring `text_1` against `text_2`.
    model: Optional[str] = None
    text_1: Union[list[str], str, ScoreMultiModalParam]
    text_2: Union[list[str], str, ScoreMultiModalParam]
    # Constrained to >= -1 at validation time.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:score-extra-params]

    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )

    activation: Optional[bool] = None

    # --8<-- [end:score-extra-params]

    def to_pooling_params(self):
        # Only the activation toggle is forwarded to the pooler.
        activation = self.activation
        return PoolingParams(activation=activation)
1425
1426


1427
class RerankRequest(OpenAIBaseModel):
    # Request body for reranking `documents` against a `query`.
    model: Optional[str] = None
    query: Union[str, ScoreMultiModalParam]
    documents: Union[list[str], ScoreMultiModalParam]
    top_n: int = Field(default_factory=lambda: 0)
    # Constrained to >= -1 at validation time.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

    # --8<-- [start:rerank-extra-params]

    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )

    activation: Optional[bool] = None

    # --8<-- [end:rerank-extra-params]

    def to_pooling_params(self):
        # Only the activation toggle is forwarded to the pooler.
        activation = self.activation
        return PoolingParams(activation=activation)
1455
1456
1457


class RerankDocument(BaseModel):
    # A reranked document: plain text and/or a multi-modal content part.
    text: Optional[str] = None
    multi_modal: Optional[ScoreContentPartParam] = None
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475


class RerankResult(BaseModel):
    # One scored document, identified by its index in the request.
    index: int
    document: RerankDocument
    relevance_score: float


class RerankUsage(BaseModel):
    # Token accounting for a rerank call.
    total_tokens: int


class RerankResponse(OpenAIBaseModel):
    # Response body of the rerank endpoint.
    id: str
    model: str
    usage: RerankUsage
    results: list[RerankResult]
1477
1478


1479
class CompletionLogProbs(OpenAIBaseModel):
    # Log-probability details for a legacy completion choice; every list
    # starts out empty.
    text_offset: list[int] = Field(default_factory=list)
    token_logprobs: list[Optional[float]] = Field(default_factory=list)
    tokens: list[str] = Field(default_factory=list)
    top_logprobs: list[Optional[dict[str, float]]] = Field(
        default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
1485
1486


1487
class CompletionResponseChoice(OpenAIBaseModel):
    # One choice of a non-streaming completion response.
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    # Generated token ids of this choice.
    token_ids: Optional[list[int]] = None
    prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
    # Token ids of the prompt.
    prompt_token_ids: Optional[list[int]] = None
Zhuohan Li's avatar
Zhuohan Li committed
1502
1503


1504
class CompletionResponse(OpenAIBaseModel):
    # Non-streaming response of the legacy completions endpoint.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: Literal["text_completion"] = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseChoice]
    service_tier: Optional[Literal["auto", "default", "flex", "scale",
                                   "priority"]] = None
    system_fingerprint: Optional[str] = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None, description="KVTransfer parameters.")
Zhuohan Li's avatar
Zhuohan Li committed
1518
1519


1520
class CompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed completion chunk.
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    # Token tracing (not in the OpenAI spec). Prompt ids live on the choice
    # to mirror CompletionResponseChoice.
    prompt_token_ids: Optional[list[int]] = None
    token_ids: Optional[list[int]] = None
Zhuohan Li's avatar
Zhuohan Li committed
1536
1537


1538
class CompletionStreamResponse(OpenAIBaseModel):
    # One streamed chunk of a legacy completion.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
1545
1546


1547
class EmbeddingResponseData(OpenAIBaseModel):
    # One embedding: a float list, or a string-encoded form.
    index: int
    object: str = "embedding"
    embedding: Union[list[float], str]
1551
1552


1553
class EmbeddingResponse(OpenAIBaseModel):
    # Response body of the embeddings endpoint.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[EmbeddingResponseData]
    usage: UsageInfo


1562
1563
1564
class PoolingResponseData(OpenAIBaseModel):
    # One pooled output: a 1-D or 2-D float list, or a string-encoded form.
    index: int
    object: str = "pooling"
    data: Union[list[list[float]], list[float], str]
1566
1567
1568
1569
1570
1571
1572


class PoolingResponse(OpenAIBaseModel):
    # Response body of the pooling endpoint.
    id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[PoolingResponseData]
    usage: UsageInfo


1577
1578
1579
class ScoreResponseData(OpenAIBaseModel):
    # One score, identified by its position in the request.
    index: int
    object: str = "score"
    score: float
1581
1582
1583
1584
1585
1586
1587


class ScoreResponse(OpenAIBaseModel):
    # Response body of the score endpoint.
    # Ids use a score-specific prefix, in line with the per-endpoint prefixes
    # of the other responses (cmpl-, embd-, pool-, classify-), so score ids
    # cannot be confused with embedding ids.
    id: str = Field(default_factory=lambda: f"score-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ScoreResponseData]
    usage: UsageInfo


1592
1593
1594
1595
1596
1597
class ClassificationRequest(OpenAIBaseModel):
    # Request body of the classification endpoint.
    model: Optional[str] = None
    input: Union[list[str], str]
    # Same ge=-1 constraint as ScoreRequest/RerankRequest, so out-of-range
    # values are rejected at validation instead of reaching the engine.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
    user: Optional[str] = None

    # --8<-- [start:classification-extra-params]
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )

    activation: Optional[bool] = None

    # --8<-- [end:classification-extra-params]

    def to_pooling_params(self):
        # Only the activation toggle is forwarded to the pooler.
        return PoolingParams(activation=self.activation)
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630


class ClassificationData(OpenAIBaseModel):
    # One classification result. `label` is required but may be None.
    index: int
    label: Optional[str]
    probs: list[float]
    num_classes: int


class ClassificationResponse(OpenAIBaseModel):
    # Response body of the classification endpoint.
    id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ClassificationData]
    usage: UsageInfo


1631
1632
1633
1634
1635
1636
class FunctionCall(OpenAIBaseModel):
    # A function invocation; `arguments` is kept as a raw string.
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    # A complete tool call; a fresh id is generated when none is supplied.
    id: str = Field(default_factory=random_tool_call_id)
    type: Literal["function"] = "function"
    function: FunctionCall


1642
1643
1644
1645
1646
1647
1648
class DeltaFunctionCall(BaseModel):
    # Partial function call emitted while streaming; both parts optional.
    name: Optional[str] = None
    arguments: Optional[str] = None


# A streamed tool-call delta: every field except `index` is optional.
class DeltaToolCall(OpenAIBaseModel):
    id: Optional[str] = None
    type: Optional[Literal["function"]] = None
    index: int
    function: Optional[DeltaFunctionCall] = None


class ExtractedToolCallInformation(BaseModel):
    # Whether any tool call was found in the model output.
    tools_called: bool

    # The tool calls that were parsed out.
    tool_calls: list[ToolCall]

    # Remaining text content. The OpenAI spec rarely returns content together
    # with tool calls, but some models deliberately produce both.
    content: Optional[str] = None


1667
class ChatMessage(OpenAIBaseModel):
    # An assistant message in a non-streaming chat completion choice.
    role: str
    content: Optional[str] = None
    refusal: Optional[str] = None
    # The OpenAI chat message schema defines `annotations` as a list of
    # annotations, not a single annotation object.
    annotations: Optional[list[OpenAIAnnotation]] = None
    audio: Optional[OpenAIChatCompletionAudio] = None
    function_call: Optional[FunctionCall] = None
    tool_calls: list[ToolCall] = Field(default_factory=list)

    # vLLM-specific fields that are not in OpenAI spec
    reasoning_content: Optional[str] = None

1679

1680
1681
1682
class ChatCompletionLogProb(OpenAIBaseModel):
    # One token with its logprob; `logprob` falls back to -9999.0.
    token: str
    logprob: float = -9999.0
    bytes: Optional[list[int]] = None
1684
1685
1686


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    # Redeclared here so this subclass gets its own field-name cache rather
    # than sharing the parent's.
    field_names: ClassVar[Optional[set[str]]] = None
    # Alternatives considered at this position.
    top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
1691
1692
1693


class ChatCompletionLogProbs(OpenAIBaseModel):
    # Logprob entries for a chat choice, when requested.
    content: Optional[list[ChatCompletionLogProbsContent]] = None
1695
1696


1697
class ChatCompletionResponseChoice(OpenAIBaseModel):
    # One choice of a non-streaming chat completion.
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # "stop" is the OpenAI-spec default.
    finish_reason: Optional[str] = "stop"
    # Kept for legacy reasons; not part of the OpenAI spec.
    stop_reason: Optional[Union[int, str]] = None
    # Not in the OpenAI spec; handy for tracing tokens in agent setups.
    token_ids: Optional[list[int]] = None
1708
1709


1710
class ChatCompletionResponse(OpenAIBaseModel):
    # Non-streaming response of the chat completions endpoint.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseChoice]
    service_tier: Optional[Literal["auto", "default", "flex", "scale",
                                   "priority"]] = None
    system_fingerprint: Optional[str] = None
    usage: UsageInfo

    # vLLM-specific fields that are not in OpenAI spec
    prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
    prompt_token_ids: Optional[list[int]] = None
    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None, description="KVTransfer parameters.")
1726
1727


1728
class DeltaMessage(OpenAIBaseModel):
    # Incremental message content carried by one streamed chunk.
    role: Optional[str] = None
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)
1733
1734


1735
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed chat completion chunk.
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None
    # Token tracing; not part of the OpenAI spec.
    token_ids: Optional[list[int]] = None
1743
1744


1745
class ChatCompletionStreamResponse(OpenAIBaseModel):
    # One streamed chunk of a chat completion.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
    # Token tracing; not part of the OpenAI spec.
    prompt_token_ids: Optional[list[int]] = None
1754
1755


1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed transcription chunk.
    delta: DeltaMessage
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class TranscriptionStreamResponse(OpenAIBaseModel):
    # One streamed chunk of a transcription.
    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    object: Literal["transcription.chunk"] = "transcription.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
class InputTokensDetails(OpenAIBaseModel):
    # Breakdown of input token usage.
    cached_tokens: int


class OutputTokensDetails(OpenAIBaseModel):
    # Breakdown of output token usage.
    reasoning_tokens: int


class ResponseUsage(OpenAIBaseModel):
    # Token accounting reported by the Responses API.
    input_tokens: int
    input_tokens_details: InputTokensDetails
    output_tokens: int
    output_tokens_details: OutputTokensDetails
    total_tokens: int
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795


class ResponsesResponse(OpenAIBaseModel):
    # Response object of the Responses API.
    id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
    created_at: int = Field(default_factory=lambda: int(time.time()))
    # error: Optional[ResponseError] = None
    # incomplete_details: Optional[IncompleteDetails] = None
    instructions: Optional[str] = None
    metadata: Optional[Metadata] = None
    model: str
    object: Literal["response"] = "response"
    output: list[ResponseOutputItem]
    parallel_tool_calls: bool
    temperature: float
    tool_choice: ToolChoice
    tools: list[Tool]
    top_p: float
    background: bool
    max_output_tokens: int
    max_tool_calls: Optional[int] = None
    previous_response_id: Optional[str] = None
    prompt: Optional[ResponsePrompt] = None
    reasoning: Optional[Reasoning] = None
    service_tier: Literal["auto", "default", "flex", "scale", "priority"]
    status: ResponseStatus
    text: Optional[ResponseTextConfig] = None
    top_logprobs: int
    truncation: Literal["auto", "disabled"]
    usage: Optional[ResponseUsage] = None
    user: Optional[str] = None

    @classmethod
    def from_request(
        cls,
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        model_name: str,
        created_time: int,
        output: list[ResponseOutputItem],
        status: ResponseStatus,
        usage: Optional[ResponseUsage] = None,
    ) -> "ResponsesResponse":
        """Assemble a response from a request and its sampling parameters.

        Request options are echoed back unchanged. Decoding settings come
        from ``sampling_params``; the remaining values are supplied by the
        caller.
        """
        # Options echoed from the incoming request.
        echoed = dict(
            id=request.request_id,
            instructions=request.instructions,
            metadata=request.metadata,
            parallel_tool_calls=request.parallel_tool_calls,
            tool_choice=request.tool_choice,
            tools=request.tools,
            background=request.background,
            max_tool_calls=request.max_tool_calls,
            previous_response_id=request.previous_response_id,
            prompt=request.prompt,
            reasoning=request.reasoning,
            service_tier=request.service_tier,
            text=request.text,
            truncation=request.truncation,
            user=request.user,
        )
        # Decoding settings as resolved into the sampling parameters.
        sampled = dict(
            temperature=sampling_params.temperature,
            top_p=sampling_params.top_p,
            max_output_tokens=sampling_params.max_tokens,
            top_logprobs=sampling_params.logprobs,
        )
        return cls(
            created_at=created_time,
            model=model_name,
            output=output,
            status=status,
            usage=usage,
            **echoed,
            **sampled,
        )


1855
1856
1857
1858
# Every request body shape accepted in a batch input line.
BatchRequestInputBody = Union[
    ChatCompletionRequest,
    EmbeddingRequest,
    ScoreRequest,
    RerankRequest,
]


1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    The `url` field decides how `body` is parsed: `/v1/chat/completions`,
    `/v1/embeddings`, and URLs ending in `/score` or `/rerank` map to their
    dedicated request models; any other URL is validated against the union
    of all supported request bodies.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. It selects the
    # request model used to parse `body` (see the class docstring).
    url: str

    # The parameters of the request.
    body: BatchRequestInputBody

    @field_validator('body', mode='plain')
    @classmethod
    def check_type_for_url(cls, value: Any, info: ValidationInfo):
        """Validate `body` against the request model implied by `url`."""
        # `url` is declared before `body`, so it is normally already in
        # `info.data`. It is absent only when its own validation failed;
        # indexing would then raise a KeyError, which pydantic does not turn
        # into a ValidationError. Fall back to the generic union instead.
        url = info.data.get("url")
        if url is None:
            return TypeAdapter(BatchRequestInputBody).validate_python(value)
        if url == "/v1/chat/completions":
            return ChatCompletionRequest.model_validate(value)
        if url == "/v1/embeddings":
            return TypeAdapter(EmbeddingRequest).validate_python(value)
        if url.endswith("/score"):
            return ScoreRequest.model_validate(value)
        if url.endswith("/rerank"):
            return RerankRequest.model_validate(value)
        return TypeAdapter(BatchRequestInputBody).validate_python(value)
1895

1896

1897
1898
1899
1900
1901
1902
1903
1904
class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response (200 unless set otherwise).
    status_code: int = 200

    # An unique identifier for the API request.
    request_id: str

    # Response payload, or None when there is none.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse,
                         ScoreResponse, RerankResponse]] = None
1907
1908


1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # Echo of the developer-provided id from the matching input line.
    custom_id: str

    # Required, but may be None.
    response: Optional[BatchResponseData]

    # Extra failure details for requests that failed with a non-HTTP error.
    error: Optional[Any]
1925
1926


1927
class TokenizeCompletionRequest(OpenAIBaseModel):
    # Tokenize a plain-text prompt.
    model: Optional[str] = None
    prompt: str

    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
    return_token_strs: Optional[bool] = Field(
        default=False,
        description=("If true, also return the token strings "
                     "corresponding to the token ids."),
    )
1942
1943
1944


class TokenizeChatRequest(OpenAIBaseModel):
    # Tokenize a chat conversation after rendering it with a chat template.
    model: Optional[str] = None
    messages: list[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    return_token_strs: Optional[bool] = Field(
        default=False,
        description=("If true, also return the token strings "
                     "corresponding to the token ids."),
    )
    continue_final_message: bool = Field(
        default=False,
        description=
        ("If this is set, the chat will be formatted so that the final "
         "message in the chat is open-ended, without any EOS tokens. The "
         "model will continue this message rather than starting a new one. "
         "This allows you to \"prefill\" part of the model's response for it. "
         "Cannot be used at the same time as `add_generation_prompt`."),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."),
    )
    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    tools: Optional[list[ChatCompletionToolsParam]] = Field(
        default=None,
        description=("A list of tools the model may call."),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject raw input that enables both prompt-ending flags.

        Only explicitly supplied values are checked.
        NOTE(review): `add_generation_prompt` defaults to True, so input that
        sets only `continue_final_message` passes this check but ends up
        with both flags True after defaults are applied. Confirm whether
        callers are expected to pass `add_generation_prompt=False`.

        Raises:
            ValueError: if both flags are truthy in the raw input.
        """
        from collections.abc import Mapping

        # "before" validators see raw input, which is not always a mapping;
        # leave non-mapping input to pydantic instead of crashing on `.get`.
        if not isinstance(data, Mapping):
            return data
        if data.get("continue_final_message") and data.get(
                "add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data

2010
2011

# A tokenize request: either a plain prompt or a chat conversation.
TokenizeRequest = Union[TokenizeCompletionRequest,
                        TokenizeChatRequest]
2012
2013
2014
2015
2016


class TokenizeResponse(OpenAIBaseModel):
    # Tokenization result; `token_strs` is set only when requested.
    count: int
    max_model_len: int
    tokens: list[int]
    token_strs: Optional[list[str]] = None
2019
2020
2021


class DetokenizeRequest(OpenAIBaseModel):
    # Turn a list of token ids back into text.
    model: Optional[str] = None
    tokens: list[int]
2024
2025
2026
2027


class DetokenizeResponse(OpenAIBaseModel):
    # The decoded text.
    prompt: str
2028
2029


2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
class TokenizerInfoResponse(OpenAIBaseModel):
    """
    Response containing tokenizer configuration 
    equivalent to tokenizer_config.json
    """

    # Extra keys are accepted, so arbitrary tokenizer config entries can be
    # passed through alongside `tokenizer_class`.
    model_config = ConfigDict(extra="allow")
    tokenizer_class: str


2040
class LoadLoRAAdapterRequest(BaseModel):
    # Name to register the adapter under, and where to load it from.
    lora_name: str
    lora_path: str


2045
class UnloadLoRAAdapterRequest(BaseModel):
    # Adapter to unload; the integer id is optional.
    lora_name: str
    lora_int_id: Optional[int] = Field(default=None)
2048
2049
2050
2051
2052
2053
2054
2055
2056


## Protocols for Audio
# Output formats accepted for transcription responses.
AudioResponseFormat: TypeAlias = Literal[
    "json", "text", "srt", "verbose_json", "vtt"]


class TranscriptionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: Optional[str] = None
    """ID of the model to use.
    """

    language: Optional[str] = None
    """The language of the input audio.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    ## TODO (varun) : Support if set to 0, certain thresholds are met !!

    timestamp_granularities: list[Literal["word", "segment"]] = Field(
        alias="timestamp_granularities[]", default=[])
    """The timestamp granularities to populate for this transcription.

    `response_format` must be set `verbose_json` to use timestamp granularities.
    Either or both of these options are supported: `word`, or `segment`. Note:
    There is no additional latency for segment timestamps, but generating word
    timestamps incurs additional latency.
    """

    stream: Optional[bool] = False
    """When set, it will enable output to be streamed in a similar fashion
    as the Chat Completion endpoint.
    """
    # --8<-- [start:transcription-extra-params]
    # Flattened stream option to simplify form data.
    stream_include_usage: Optional[bool] = False
    stream_continuous_usage_stats: Optional[bool] = False

    vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
        default=None,
        description=("Additional request parameters with string or "
                     "numeric values, used by custom extensions."),
    )
    # --8<-- [end:transcription-extra-params]

    # --8<-- [start:transcription-sampling-params]
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    top_p: Optional[float] = None
    """Enables nucleus (top-p) sampling, where tokens are selected from the
    smallest possible set whose cumulative probability exceeds `p`.
    """

    top_k: Optional[int] = None
    """Limits sampling to the `k` most probable tokens at each step."""

    min_p: Optional[float] = None
    """Filters out tokens with a probability lower than `min_p`, ensuring a
    minimum likelihood threshold during sampling.
    """

    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    frequency_penalty: Optional[float] = 0.0
    """The frequency penalty to use for sampling."""

    repetition_penalty: Optional[float] = None
    """The repetition penalty to use for sampling."""

    presence_penalty: Optional[float] = 0.0
    """The presence penalty to use for sampling."""
    # --8<-- [end:transcription-sampling-params]

    # Default sampling parameters for transcription requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_sampling_params(
            self,
            default_max_tokens: int,
            default_sampling_params: Optional[dict] = None) -> SamplingParams:
        """Build engine `SamplingParams` for this transcription request.

        Args:
            default_max_tokens: Used as `max_tokens`; this request type has
                no per-request override.
            default_sampling_params: Optional server-side defaults, consulted
                for `temperature`, `top_p`, `top_k`, `min_p` and
                `repetition_penalty` when the request leaves them as None.
                Keys missing here fall back to `_DEFAULT_SAMPLING_PARAMS`.

        Returns:
            SamplingParams whose output kind is DELTA when `stream` is set
            and FINAL_ONLY otherwise.
        """
        if default_sampling_params is None:
            default_sampling_params = {}

        def _resolve(name: str, value):
            # Precedence: explicit request value, then server default, then
            # the class-level fallback.
            if value is not None:
                return value
            return default_sampling_params.get(
                name, self._DEFAULT_SAMPLING_PARAMS[name])

        # NOTE(review): `temperature` is typed `float` with default 0.0, so
        # it is never None here and the server/class temperature defaults
        # are never used. Left unchanged to keep default decoding behavior.
        temperature = _resolve("temperature", self.temperature)
        top_p = _resolve("top_p", self.top_p)
        top_k = _resolve("top_k", self.top_k)
        min_p = _resolve("min_p", self.min_p)
        repetition_penalty = _resolve("repetition_penalty",
                                      self.repetition_penalty)

        output_kind = (RequestOutputKind.DELTA
                       if self.stream else RequestOutputKind.FINAL_ONLY)
        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=default_max_tokens,
            seed=self.seed,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            presence_penalty=self.presence_penalty,
            output_kind=output_kind,
            extra_args=self.vllm_xargs)

    @model_validator(mode="before")
    @classmethod
    def validate_transcription_request(cls, data):
        """Reject malformed raw transcription input.

        Raises:
            HTTPException: 422 when `file` was sent as a string rather than
                a file-like object.
            ValueError: when a stream option is set without `stream=True`.
        """
        from collections.abc import Mapping

        # "before" validators see raw input, which is not always a mapping;
        # leave non-mapping input to pydantic instead of crashing on `.get`.
        if not isinstance(data, Mapping):
            return data

        if isinstance(data.get("file"), str):
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail="Expected 'file' to be a file-like object, not 'str'.",
            )

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279


# Transcription response objects
class TranscriptionResponse(OpenAIBaseModel):
    # Minimal transcription result: carries only the transcribed text.
    # NOTE(review): presumably returned for the plain `json` response format,
    # with TranscriptionResponseVerbose covering `verbose_json` — confirm
    # against the serving code.
    text: str
    """The transcribed text."""


class TranscriptionWord(OpenAIBaseModel):
    # Timing entry for a single transcribed word; both timestamps are in
    # seconds. Instances are listed in TranscriptionResponseVerbose.words.
    # Field order here is the serialized key order.
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranscriptionSegment(OpenAIBaseModel):
    # One timed segment of a verbose transcription. Besides its text and
    # timing, it carries decoder statistics (avg_logprob, compression_ratio,
    # no_speech_prob) whose quality thresholds are described per field.
    # Instances are listed in TranscriptionResponseVerbose.segments.
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    # NOTE(review): a probability cannot exceed 1.0, so the "higher than 1.0"
    # threshold below (presumably mirrored from OpenAI's API reference) can
    # never be met as written — confirm the intended cut-off.
    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranscriptionResponseVerbose(OpenAIBaseModel):
    # Detailed transcription result: the full text plus audio metadata and,
    # optionally, segment-level and word-level timing breakdowns (both
    # default to None when not produced).
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: Optional[list[TranscriptionSegment]] = None
    """Segments of the transcribed text and their corresponding details."""

    words: Optional[list[TranscriptionWord]] = None
    """Extracted words and their corresponding timestamps."""


class TranslationResponseStreamChoice(OpenAIBaseModel):
    # One entry of TranslationStreamResponse.choices within a streamed chunk.
    # `delta` holds this chunk's incremental message content.
    delta: DeltaMessage
    # Both reasons default to None.
    # NOTE(review): presumably set only on the chunk that ends the choice —
    # confirm in the streaming code.
    finish_reason: Optional[str] = None
    # NOTE(review): int vs str presumably distinguishes a stop token id from
    # a matched stop string — confirm.
    stop_reason: Optional[Union[int, str]] = None


class TranslationStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed translation response.
    # If not supplied, a fresh "trsl-"-prefixed id is generated per instance.
    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
    object: Literal["translation.chunk"] = "translation.chunk"
    # Creation time as whole Unix seconds, taken when the instance is built.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranslationResponseStreamChoice]
    # None by default.
    # NOTE(review): presumably filled when the request enables
    # stream_include_usage / stream_continuous_usage_stats — confirm.
    usage: Optional[UsageInfo] = Field(default=None)


class TranslationRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranslation

    file: UploadFile
    """
    The audio file object (not file name) to translate, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: Optional[str] = None
    """ID of the model to use.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    # TODO support additional sampling parameters
    # --8<-- [start:translation-sampling-params]
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """
    # --8<-- [end:translation-sampling-params]

    # --8<-- [start:translation-extra-params]
    language: Optional[str] = None
    """The language of the input audio we translate from.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy.
    """

    stream: Optional[bool] = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
    Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: Optional[bool] = False
    stream_continuous_usage_stats: Optional[bool] = False
    # --8<-- [end:translation-extra-params]

    # Default sampling parameters for translation requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
            self,
            default_max_tokens: int,
            default_sampling_params: Optional[dict] = None) -> SamplingParams:
        """Build the engine ``SamplingParams`` for this translation request.

        Args:
            default_max_tokens: Generation token budget; used as-is.
            default_sampling_params: Optional server-side defaults consulted
                for parameters the request leaves unset. ``None`` means no
                server defaults.

        Returns:
            ``SamplingParams`` whose output kind is ``DELTA`` when streaming
            and ``FINAL_ONLY`` otherwise.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        # Precedence: request value, then server default, then class default.
        # NOTE(review): `temperature` is declared as a plain float defaulting
        # to 0.0, so after validation this fallback looks unreachable —
        # confirm whether the field was meant to be Optional.
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])

        # Streaming responses are built from incremental deltas; otherwise
        # only the final output is needed.
        output_kind = (RequestOutputKind.DELTA
                       if self.stream else RequestOutputKind.FINAL_ONLY)
        return SamplingParams.from_optional(temperature=temperature,
                                            max_tokens=max_tokens,
                                            output_kind=output_kind)

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Ensure the flattened stream options are only used with streaming.

        Raises:
            ValueError: when ``stream_include_usage`` or
                ``stream_continuous_usage_stats`` is truthy but ``stream`` is
                not.

        Returns:
            ``data`` unchanged.
        """
        # "before" validators may be handed non-dict input (e.g. an existing
        # model instance); defer to pydantic's normal validation in that case.
        if not isinstance(data, dict):
            return data

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data


# Translation response objects
class TranslationResponse(OpenAIBaseModel):
    # Minimal translation result: carries only the translated text.
    # NOTE(review): presumably returned for the plain `json` response format,
    # with TranslationResponseVerbose covering `verbose_json` — confirm.
    text: str
    """The translated text."""


class TranslationWord(OpenAIBaseModel):
    # Timing entry for a single word of the translation output; both
    # timestamps are in seconds. Field order here is the serialized key order.
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranslationSegment(OpenAIBaseModel):
    # One timed segment of a verbose translation. Besides its text and
    # timing, it carries decoder statistics (avg_logprob, compression_ratio,
    # no_speech_prob) whose quality thresholds are described per field.
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    # NOTE(review): a probability cannot exceed 1.0, so the "higher than 1.0"
    # threshold below can never be met as written — confirm the intended
    # cut-off.
    no_speech_prob: float
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranslationResponseVerbose(OpenAIBaseModel):
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The translated text."""

    segments: Optional[list[TranslationSegment]] = None
    """Segments of the translated text and their corresponding details."""

    words: Optional[list[TranslationWord]] = None
    """Extracted words and their corresponding timestamps."""