protocol.py 9.61 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
Zhuohan Li's avatar
Zhuohan Li committed
6
import time
7
from typing import Any, ClassVar, Literal, TypeAlias
Zhuohan Li's avatar
Zhuohan Li committed
8

9
import regex as re
10
11
12
13
14
15
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    model_validator,
)
Zhuohan Li's avatar
Zhuohan Li committed
16

17
from vllm.entrypoints.chat_utils import make_tool_call_id
18
from vllm.logger import init_logger
19
from vllm.sampling_params import SamplingParams
20
21
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
22

23
24
# Module-level logger, named after this module via `__name__`.
logger = init_logger(__name__)

Zhuohan Li's avatar
Zhuohan Li committed
25

26
class OpenAIBaseModel(BaseModel):
    """Base model that accepts unknown fields and logs a warning about them.

    Extra keys are allowed by the model config; the wrap validator below
    reports any key of a dict payload that is neither a declared field name
    nor a declared alias.
    """

    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Per-class cache of declared field names and aliases (filled lazily).
    field_names: ClassVar[set[str] | None] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Validate `data` with `handler`, then warn about unknown keys.

        Args:
            data: The raw input being validated. Only `dict` inputs are
                inspected for unknown keys.
            handler: The wrapped validation callable supplied by pydantic.

        Returns:
            Whatever `handler(data)` returns, unchanged.
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Look the cache up in this class's own namespace only. A plain
        # `cls.field_names` would fall back to a parent class's cache, so a
        # subclass validated after its parent would be checked against the
        # parent's fields and its own declared fields would be reported.
        field_names = cls.__dict__.get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        unknown_keys = data.keys() - field_names
        if unknown_keys:
            logger.warning(
                "The following fields were present in the request but ignored: %s",
                unknown_keys,
            )
        return result
56
57


58
# Details of a single error: a message, an error type, an optional offending
# parameter name, and an integer code.
class ErrorInfo(OpenAIBaseModel):
    message: str
    type: str
    param: str | None = None
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
63
64


65
66
67
68
class ErrorResponse(OpenAIBaseModel):
    # Envelope carrying one ErrorInfo under the `error` key.
    error: ErrorInfo


69
# Permission entry listed in `ModelCard.permission`.
class ModelPermission(OpenAIBaseModel):
    # Generated per instance as "modelperm-" followed by a random uuid.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    # Integer timestamp taken from time.time() when the instance is built.
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: str | None = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
82
83


84
# Description of one model as listed in `ModelList.data`.
class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    # Integer timestamp taken from time.time() when the instance is built.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: str | None = None
    parent: str | None = None
    max_model_len: int | None = None
    # A fresh empty list per instance unless permissions are supplied.
    permission: list[ModelPermission] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
93
94


95
# List envelope for model cards; `data` defaults to a fresh empty list.
class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: list[ModelCard] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
98
99


100
# Prompt-side usage breakdown, referenced by `UsageInfo.prompt_tokens_details`.
class PromptTokenUsageInfo(OpenAIBaseModel):
    cached_tokens: int | None = None
102
103


104
# Token counters for a request; every counter defaults to 0 and the
# prompt-token breakdown is optional.
class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: int | None = 0
    prompt_tokens_details: PromptTokenUsageInfo | None = None
Zhuohan Li's avatar
Zhuohan Li committed
109
110


111
112
# Per-request metadata: the request id plus the final usage info, if any.
# Derives from plain BaseModel, so the extra-field logging of
# OpenAIBaseModel does not apply here.
class RequestResponseMetadata(BaseModel):
    request_id: str
    final_usage_info: UsageInfo | None = None
114
115


116
117
# Named JSON schema used by `ResponseFormat.json_schema`.
class JsonSchemaResponseFormat(OpenAIBaseModel):
    name: str
    description: str | None = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    strict: bool | None = None
123
124


125
# One structure entry of `LegacyStructuralTagResponseFormat.structures`:
# a `begin` string, an optional schema, and an `end` string.
class LegacyStructuralTag(OpenAIBaseModel):
    begin: str
    # schema is the field, but that causes conflicts with pydantic so
    # instead use structural_tag_schema with an alias
    structural_tag_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    end: str


133
# Legacy "structural_tag" response format: a list of structures plus a list
# of trigger strings.
class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
    type: Literal["structural_tag"]
    structures: list[LegacyStructuralTag]
    triggers: list[str]


139
140
141
142
143
144
145
146
147
148
class StructuralTagResponseFormat(OpenAIBaseModel):
    # "structural_tag" response format whose `format` payload is not
    # constrained by this model (typed as Any).
    type: Literal["structural_tag"]
    format: Any


# Either flavour of "structural_tag" response format (legacy or current).
AnyStructuralTagResponseFormat: TypeAlias = (
    LegacyStructuralTagResponseFormat | StructuralTagResponseFormat
)


149
# Response format selector with an optional JSON schema payload.
class ResponseFormat(OpenAIBaseModel):
    # type must be "json_schema", "json_object", or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: JsonSchemaResponseFormat | None = None
153
154


155
156
157
# Any response format defined in this module: the plain ResponseFormat or
# one of the two structural-tag variants.
AnyResponseFormat: TypeAlias = (
    ResponseFormat | StructuralTagResponseFormat | LegacyStructuralTagResponseFormat
)
158
159


160
# Streaming options: `include_usage` defaults to True and
# `continuous_usage_stats` defaults to False.
class StreamOptions(OpenAIBaseModel):
    include_usage: bool | None = True
    continuous_usage_stats: bool | None = False
163
164


165
166
# Declaration of a function: required name, optional description, and an
# optional free-form `parameters` mapping.
class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: str | None = None
    parameters: dict[str, Any] | None = None
169
170


171
172
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
173
174
class LogitsProcessorConstructor(BaseModel):
    # Qualified name of the object to resolve; `get_logits_processors` calls
    # the resolved object with `args` and `kwargs`.
    qualname: str
    args: list[Any] | None = None
    kwargs: dict[str, Any] | None = None

    # Unknown keys are rejected for this model.
    model_config = ConfigDict(extra="forbid")

180

181
# A logits-processor list: each entry is either a qualified-name string or a
# LogitsProcessorConstructor carrying constructor arguments.
LogitsProcessors = list[str | LogitsProcessorConstructor]
182
183


184
def get_logits_processors(
    processors: LogitsProcessors | None, pattern: str | None
) -> list[Any] | None:
    """Resolve requested logits processors, checking each against `pattern`.

    Args:
        processors: Requested processors. A plain string is a qualified name
            whose resolved object is used as-is; a
            `LogitsProcessorConstructor` has its resolved object called with
            the given `args` / `kwargs`.
        pattern: Regular expression that every qualified name must match
            (tested with `re.match`). A falsy value means the server does
            not accept logits processors.

    Returns:
        The list of resolved processors, or `None` when `processors` is
        empty or `None`.

    Raises:
        ValueError: If processors are requested but `pattern` is falsy, if a
            qualified name does not match `pattern`, or if a qualified name
            cannot be resolved.
    """
    # Nothing requested: nothing to resolve, whatever the pattern is.
    if not processors:
        return None

    # Processors requested, but no allow-pattern was given.
    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information."
        )

    logits_processors = []
    for processor in processors:
        qualname = processor if isinstance(processor, str) else processor.qualname
        if not re.match(pattern, qualname):
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information."
            )
        try:
            logits_processor = resolve_obj_by_qualname(qualname)
        except Exception as e:
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {e}"
            ) from e
        if isinstance(processor, LogitsProcessorConstructor):
            # Constructor form: call the resolved object with the supplied
            # arguments, treating missing args/kwargs as empty.
            logits_processor = logits_processor(
                *processor.args or [], **processor.kwargs or {}
            )
        logits_processors.append(logits_processor)
    return logits_processors


218
# Function name and argument string of a tool call.
class FunctionCall(OpenAIBaseModel):
    # Internal field to preserve native tool call ID from tool parser.
    # Excluded from serialization to maintain OpenAI API compatibility
    # (function object should only contain 'name' and 'arguments').
    id: str | None = Field(default=None, exclude=True)
    name: str
    arguments: str


# A complete tool call wrapping a FunctionCall.
class ToolCall(OpenAIBaseModel):
    # Generated by make_tool_call_id when the caller supplies no id.
    id: str = Field(default_factory=make_tool_call_id)
    type: Literal["function"] = "function"
    function: FunctionCall


233
# Function part of a tool-call delta; both fields are optional. Derives from
# plain BaseModel rather than OpenAIBaseModel.
class DeltaFunctionCall(BaseModel):
    name: str | None = None
    arguments: str | None = None
236
237
238
239


# a tool call delta where everything except `index` is optional
class DeltaToolCall(OpenAIBaseModel):
    id: str | None = None
    type: Literal["function"] | None = None
    # Required: the only field of the delta without a default.
    index: int
    function: DeltaFunctionCall | None = None
244
245
246
247
248
249
250


# Holds extracted tool calls together with any accompanying text content.
class ExtractedToolCallInformation(BaseModel):
    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: list[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: str | None = None
256
257


258
# Message delta: optional role, content and reasoning text, plus a list of
# tool-call deltas that defaults to a fresh empty list.
class DeltaMessage(OpenAIBaseModel):
    role: str | None = None
    content: str | None = None
    reasoning: str | None = None
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)
263
264


265
266
267
####### Tokens IN <> Tokens OUT #######
# Generation request that takes token ids as input, together with sampling
# parameters and optional streaming, cache-salt, priority and KV-transfer
# settings.
class GenerateRequest(BaseModel):
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    token_ids: list[int]
    """The token ids to generate text from."""

    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None
    """The processed MM inputs for the model."""

    sampling_params: SamplingParams
    """The sampling parameters for the model."""

    model: str | None = None

    stream: bool | None = False
    stream_options: StreamOptions | None = None
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )