protocol.py 9.93 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
Zhuohan Li's avatar
Zhuohan Li committed
6
import time
7
from typing import Any, ClassVar, Literal, TypeAlias
Zhuohan Li's avatar
Zhuohan Li committed
8

9
import regex as re
10
11
12
13
14
15
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    model_validator,
)
Zhuohan Li's avatar
Zhuohan Li committed
16

17
from vllm.entrypoints.chat_utils import make_tool_call_id
18
from vllm.logger import init_logger
19
20
21
from vllm.sampling_params import (
    SamplingParams,
)
22
23
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
24

25
26
logger = init_logger(__name__)

Zhuohan Li's avatar
Zhuohan Li committed
27

28
class OpenAIBaseModel(BaseModel):
    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Lazily built, per-class cache of declared field names and aliases.
    field_names: ClassVar[set[str] | None] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Validate `data` and warn about keys that match no declared field.

        The wrapped `handler` always runs first and its result is always
        returned unchanged; this validator only adds a log message. The
        check is skipped when `data` is not a dict.
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Look in this class's own namespace only. A plain `cls.field_names`
        # lookup walks the MRO, so a subclass would pick up the set cached on
        # an already-validated parent model and then report its own
        # additional fields as unknown.
        field_names = cls.__dict__.get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        extra_fields = data.keys() - field_names
        if extra_fields:
            logger.warning(
                "The following fields were present in the request but ignored: %s",
                extra_fields,
            )
        return result
58
59


60
class ErrorInfo(OpenAIBaseModel):
    # Details of one error. `message`, `type` and `code` are required;
    # `param` may be omitted and then defaults to None.
    message: str
    type: str
    param: str | None = None
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
65
66


67
68
69
70
class ErrorResponse(OpenAIBaseModel):
    # Envelope that nests the error details under a single `error` key.
    error: ErrorInfo


71
class ModelPermission(OpenAIBaseModel):
    # Permission entry listed on a ModelCard. Every field has a default, so
    # an instance can be created without arguments.

    # Fresh per instance: "modelperm-" followed by the random_uuid() value.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    # Wall-clock time at instantiation, truncated to whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: str | None = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
84
85


86
class ModelCard(OpenAIBaseModel):
    # Description of a single model; only `id` has to be supplied.
    id: str
    object: str = "model"
    # Wall-clock time at instantiation, truncated to whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: str | None = None
    parent: str | None = None
    max_model_len: int | None = None
    # default_factory hands every card its own new list.
    permission: list[ModelPermission] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
95
96


97
class ModelList(OpenAIBaseModel):
    # List wrapper around ModelCard entries; `data` starts out empty.
    object: str = "list"
    data: list[ModelCard] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
100
101


102
class PromptTokenUsageInfo(OpenAIBaseModel):
    # Detail record used by UsageInfo.prompt_tokens_details; holds a single
    # optional counter.
    cached_tokens: int | None = None
104
105


106
class UsageInfo(OpenAIBaseModel):
    # Token counters for a request; all three counters default to 0.
    prompt_tokens: int = 0
    total_tokens: int = 0
    # Unlike the other counters this one also accepts None.
    completion_tokens: int | None = 0
    prompt_tokens_details: PromptTokenUsageInfo | None = None
Zhuohan Li's avatar
Zhuohan Li committed
111
112


113
114
class RequestResponseMetadata(BaseModel):
    # Pairs a request id with an optional final usage record. Derives from
    # pydantic's BaseModel directly rather than from OpenAIBaseModel.
    request_id: str
    final_usage_info: UsageInfo | None = None
116
117


118
119
class JsonSchemaResponseFormat(OpenAIBaseModel):
    name: str
    description: str | None = None
    # OpenAI calls this field "schema", which clashes with pydantic, so the
    # attribute is named json_schema and "schema" is kept as its alias.
    json_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    strict: bool | None = None
125
126


127
class LegacyStructuralTag(OpenAIBaseModel):
    begin: str
    # The wire name is "schema", which clashes with pydantic, so the
    # attribute is named structural_tag_schema with "schema" as its alias.
    structural_tag_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    end: str


135
class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
    # Response format built from a list of LegacyStructuralTag entries plus
    # a list of trigger strings; `type` must be exactly "structural_tag".
    type: Literal["structural_tag"]
    structures: list[LegacyStructuralTag]
    triggers: list[str]


141
142
143
144
145
146
147
148
149
150
class StructuralTagResponseFormat(OpenAIBaseModel):
    # `type` must be exactly "structural_tag"; `format` is typed Any, so
    # this model places no constraint on its shape.
    type: Literal["structural_tag"]
    format: Any


# Either of the two structural-tag response format models.
AnyStructuralTagResponseFormat: TypeAlias = (
    LegacyStructuralTagResponseFormat | StructuralTagResponseFormat
)


151
class ResponseFormat(OpenAIBaseModel):
    # `type` accepts only "text", "json_object" or "json_schema".
    type: Literal["text", "json_object", "json_schema"]
    # Optional schema description; nothing in this class ties its presence
    # to the value of `type`.
    json_schema: JsonSchemaResponseFormat | None = None
155
156


157
158
159
# Union of every response format model defined in this module.
AnyResponseFormat: TypeAlias = (
    ResponseFormat | StructuralTagResponseFormat | LegacyStructuralTagResponseFormat
)
160
161


162
class StreamOptions(OpenAIBaseModel):
    # Streaming switches: `include_usage` defaults to True and
    # `continuous_usage_stats` to False; both also accept None.
    include_usage: bool | None = True
    continuous_usage_stats: bool | None = False
165
166


167
168
class FunctionDefinition(OpenAIBaseModel):
    # Declares a function by name, with an optional description and an
    # optional free-form `parameters` mapping.
    name: str
    description: str | None = None
    parameters: dict[str, Any] | None = None
171
172


173
174
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
175
176
class LogitsProcessorConstructor(BaseModel):
    # Unknown keys are rejected for this model.
    model_config = ConfigDict(extra="forbid")

    # Qualified name that get_logits_processors resolves to an object.
    qualname: str
    # Positional and keyword arguments that get_logits_processors passes
    # when it calls the resolved object; None is treated as "no arguments".
    args: list[Any] | None = None
    kwargs: dict[str, Any] | None = None

182

183
# Each entry is either a bare qualified name or a LogitsProcessorConstructor
# carrying constructor arguments. Marked as an explicit TypeAlias to match
# the other module-level aliases in this file.
LogitsProcessors: TypeAlias = list[str | LogitsProcessorConstructor]
184
185


186
def get_logits_processors(
    processors: LogitsProcessors | None, pattern: str | None
) -> list[Any] | None:
    """Resolve the requested logits processors, gated by an allow-pattern.

    Args:
        processors: Qualified names, or LogitsProcessorConstructor entries
            that also carry call arguments. None or an empty list means
            nothing was requested.
        pattern: Regular expression each qualified name must match (tested
            with `re.match`). None or an empty string means the server
            accepts no logits processors.

    Returns:
        The resolved objects in request order, or None when no processors
        were requested. A constructor entry yields the result of calling
        the resolved object with its args/kwargs.

    Raises:
        ValueError: If processors were requested but no pattern is
            configured, if a name does not match the pattern, or if a name
            cannot be resolved.
    """
    if not processors:
        return None

    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information."
        )

    resolved = []
    for spec in processors:
        if isinstance(spec, str):
            qualname = spec
        else:
            qualname = spec.qualname

        # Refuse anything outside the allow-pattern before resolving it.
        if re.match(pattern, qualname) is None:
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information."
            )

        try:
            target = resolve_obj_by_qualname(qualname)
        except Exception as exc:
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {exc}"
            ) from exc

        # Constructor entries are called to build the processor; bare names
        # are used as resolved.
        if isinstance(spec, LogitsProcessorConstructor):
            positional = spec.args or []
            keyword = spec.kwargs or {}
            target = target(*positional, **keyword)

        resolved.append(target)

    return resolved


220
class FunctionCall(OpenAIBaseModel):
    # Carries the tool parser's native tool call ID for internal use. It is
    # declared with exclude=True so serialized output keeps the OpenAI
    # shape, where the function object holds only 'name' and 'arguments'.
    id: str | None = Field(default=None, exclude=True)
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    # When no id is given, make_tool_call_id() supplies one per instance.
    id: str = Field(default_factory=make_tool_call_id)
    # "function" is the only accepted value and also the default.
    type: Literal["function"] = "function"
    function: FunctionCall


235
class DeltaFunctionCall(BaseModel):
    # Function-call fragment in which both parts may be absent. Derives from
    # pydantic's BaseModel directly rather than from OpenAIBaseModel.
    name: str | None = None
    arguments: str | None = None
238
239
240
241


# a tool call delta where every field except `index` is optional
class DeltaToolCall(OpenAIBaseModel):
    id: str | None = None
    type: Literal["function"] | None = None
    # The only field without a default: it must always be provided.
    index: int
    function: DeltaFunctionCall | None = None
246
247
248
249
250
251
252


class ExtractedToolCallInformation(BaseModel):
    # Whether any tools were called.
    tools_called: bool

    # The tool calls that were extracted.
    tool_calls: list[ToolCall]

    # Text content. Per the OpenAI spec it is rare for content and tool
    # calls to be returned together, but some models do so on purpose.
    content: str | None = None
258
259


260
class DeltaMessage(OpenAIBaseModel):
    # Incremental message fragment; every field is optional and
    # `tool_calls` defaults to a new empty list.
    role: str | None = None
    content: str | None = None
    reasoning: str | None = None
    reasoning_content: str | None = None
    """Deprecated: use `reasoning` instead."""
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)

    @model_validator(mode="after")
    def handle_deprecated_reasoning_content(self):
        """Mirror `reasoning` into the deprecated `reasoning_content` field.

        Runs after validation and returns the same instance.
        """
        # NOTE(review): the assignment is unconditional, so a value passed
        # only as `reasoning_content` is replaced by `reasoning` (possibly
        # None) - confirm that no caller still relies on the old field.
        self.reasoning_content = self.reasoning
        return self

274

275
276
277
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
    # Request model that carries token ids plus sampling parameters.
    # `token_ids` and `sampling_params` are required; every other field has
    # a default.
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    token_ids: list[int]
    """The token ids to generate text from."""

    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None
    """The processed MM inputs for the model."""

    sampling_params: SamplingParams
    """The sampling parameters for the model."""

    model: str | None = None

    stream: bool | None = False
    stream_options: StreamOptions | None = None
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )