protocol.py 9.72 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
Zhuohan Li's avatar
Zhuohan Li committed
6
import time
7
from typing import Any, ClassVar, Literal, TypeAlias
Zhuohan Li's avatar
Zhuohan Li committed
8

9
import regex as re
10
import torch
11
12
13
14
15
16
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    model_validator,
)
Zhuohan Li's avatar
Zhuohan Li committed
17

18
from vllm.entrypoints.chat_utils import make_tool_call_id
19
from vllm.logger import init_logger
20
21
22
from vllm.sampling_params import (
    SamplingParams,
)
23
24
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
25

26
27
# Module-level logger obtained from vLLM's init_logger for this module's name;
# used below to warn about unrecognised request fields.
logger = init_logger(__name__)

28
# Integer type info (e.g. min/max) for torch.long.
# NOTE(review): not referenced anywhere else in this file as shown - confirm
# whether it is still needed by importers before removing.
_LONG_INFO = torch.iinfo(torch.long)
29

Zhuohan Li's avatar
Zhuohan Li committed
30

31
class OpenAIBaseModel(BaseModel):
    """Base class for the OpenAI-compatible protocol models in this module.

    Unknown fields are accepted (``extra="allow"``), and a warning listing
    them is logged whenever a model is validated from a ``dict``.
    """

    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Cache class field names
    field_names: ClassVar[set[str] | None] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Validate ``data`` and warn about keys matching no field or alias.

        Args:
            data: Raw input being validated; only ``dict`` input is inspected.
            handler: The inner validation callable supplied by pydantic.

        Returns:
            The result of ``handler(data)``, unchanged.
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Read the cache from this class's own namespace instead of through
        # normal attribute lookup: otherwise a subclass of an
        # already-validated model would inherit the parent's cached set and
        # report its own additional fields as unknown.
        field_names = cls.__dict__.get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        extra_fields = data.keys() - field_names
        if extra_fields:
            logger.warning(
                "The following fields were present in the request but ignored: %s",
                extra_fields,
            )
        return result
61
62


63
class ErrorInfo(OpenAIBaseModel):
    """Details of one error; used as the ``error`` field of ``ErrorResponse``."""

    # Required error text.
    message: str
    # Required error category string.
    type: str
    # Optional; presumably names the offending request parameter - confirm
    # against the code that builds errors.
    param: str | None = None
    # Required numeric code (NOTE(review): looks like an HTTP status - confirm).
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
68
69


70
71
72
73
class ErrorResponse(OpenAIBaseModel):
    """Error envelope: wraps a single ``ErrorInfo`` under the ``error`` key."""

    error: ErrorInfo


74
class ModelPermission(OpenAIBaseModel):
    """Permission entry listed in ``ModelCard.permission``.

    Every field has a default, so ``ModelPermission()`` is valid. ``id`` and
    ``created`` are produced by default factories, i.e. computed anew for each
    instance that does not supply them.
    """

    # "modelperm-" followed by a fresh random_uuid() value.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    # Construction time, truncated to whole seconds since the epoch.
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: str | None = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
87
88


89
class ModelCard(OpenAIBaseModel):
    """Description of one model; the element type of ``ModelList.data``.

    Only ``id`` is required; every other field has a default.
    """

    id: str
    object: str = "model"
    # Construction time, truncated to whole seconds since the epoch.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: str | None = None
    parent: str | None = None
    max_model_len: int | None = None
    # A new empty list per instance unless permissions are supplied.
    permission: list[ModelPermission] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
98
99


100
class ModelList(OpenAIBaseModel):
    """List envelope holding ``ModelCard`` entries under ``data``."""

    object: str = "list"
    # A new empty list per instance unless cards are supplied.
    data: list[ModelCard] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
103
104


105
class PromptTokenUsageInfo(OpenAIBaseModel):
    """Prompt-token breakdown used by ``UsageInfo.prompt_tokens_details``."""

    # Optional count; None when not provided.
    cached_tokens: int | None = None
107
108


109
class UsageInfo(OpenAIBaseModel):
    """Token usage counters; every field has a default.

    Nothing in this class enforces a relationship between the counters (for
    example that ``total_tokens`` equals the sum of the others); producers are
    responsible for filling them consistently.
    """

    prompt_tokens: int = 0
    total_tokens: int = 0
    # Unlike the two counters above, this one also accepts None.
    completion_tokens: int | None = 0
    prompt_tokens_details: PromptTokenUsageInfo | None = None
Zhuohan Li's avatar
Zhuohan Li committed
114
115


116
117
class RequestResponseMetadata(BaseModel):
    """Per-request metadata: the request id plus optional final usage.

    Derives from plain ``BaseModel`` rather than ``OpenAIBaseModel``, so the
    extra-field allowance and warning of that base class do not apply here.
    """

    request_id: str
    # None until a final UsageInfo is assigned (by code outside this class).
    final_usage_info: UsageInfo | None = None
119
120


121
122
class JsonSchemaResponseFormat(OpenAIBaseModel):
    """Named JSON schema; the type of ``ResponseFormat.json_schema``.

    Only ``name`` is required.
    """

    name: str
    description: str | None = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    strict: bool | None = None
128
129


130
class LegacyStructuralTag(OpenAIBaseModel):
    """One entry of ``LegacyStructuralTagResponseFormat.structures``.

    ``begin`` and ``end`` are required strings; the schema is optional.
    """

    begin: str
    # schema is the field, but that causes conflicts with pydantic so
    # instead use structural_tag_schema with an alias
    structural_tag_schema: dict[str, Any] | None = Field(default=None, alias="schema")
    end: str


138
class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
    """Legacy ``structural_tag`` response format.

    All three fields are required: the fixed ``type`` discriminator, the list
    of ``LegacyStructuralTag`` structures, and a list of trigger strings.
    """

    type: Literal["structural_tag"]
    structures: list[LegacyStructuralTag]
    triggers: list[str]


144
145
146
147
148
149
150
151
152
153
class StructuralTagResponseFormat(OpenAIBaseModel):
    """``structural_tag`` response format carrying an untyped ``format`` payload.

    Both fields are required; ``format`` is annotated ``Any``, so this model
    places no constraint on its shape.
    """

    type: Literal["structural_tag"]
    format: Any


# Either structural-tag response format: the legacy variant (structures +
# triggers) or the current one (free-form `format`).
AnyStructuralTagResponseFormat: TypeAlias = (
    LegacyStructuralTagResponseFormat | StructuralTagResponseFormat
)


154
class ResponseFormat(OpenAIBaseModel):
    """Response format selector with an optional JSON schema payload.

    This model does not itself require ``json_schema`` to be present when
    ``type`` is ``"json_schema"``.
    """

    # type must be "json_schema", "json_object", or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: JsonSchemaResponseFormat | None = None
158
159


160
161
162
# Any response format defined in this module: the plain ResponseFormat
# (text / json_object / json_schema) or either structural-tag variant.
AnyResponseFormat: TypeAlias = (
    ResponseFormat | StructuralTagResponseFormat | LegacyStructuralTagResponseFormat
)
163
164


165
class StreamOptions(OpenAIBaseModel):
    """Usage-reporting options for streamed responses.

    Both flags are optional; how they are honoured is decided by the code
    that consumes this model, not by the model itself.
    """

    # Defaults to True.
    include_usage: bool | None = True
    # Defaults to False.
    continuous_usage_stats: bool | None = False
168
169


170
171
class FunctionDefinition(OpenAIBaseModel):
    """Definition of a function: required name, optional description/parameters."""

    name: str
    description: str | None = None
    # Free-form mapping (presumably a JSON-schema object - confirm with callers).
    parameters: dict[str, Any] | None = None
174
175


176
177
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
178
179
class LogitsProcessorConstructor(BaseModel):
    """A logits processor given as a qualified name plus call arguments.

    ``get_logits_processors`` resolves ``qualname`` to an object and calls it
    with ``*args`` and ``**kwargs`` (each treated as empty when ``None``).
    """

    qualname: str
    args: list[Any] | None = None
    kwargs: dict[str, Any] | None = None

    # extra="forbid" is what allows a field literally named `kwargs`
    # (workaround for https://github.com/pydantic/pydantic/issues/3125).
    model_config = ConfigDict(extra="forbid")

185

186
# A request's logits processors: each entry is either a qualified name (str)
# or a LogitsProcessorConstructor carrying the name plus call arguments.
LogitsProcessors = list[str | LogitsProcessorConstructor]
187
188


189
def get_logits_processors(
    processors: LogitsProcessors | None, pattern: str | None
) -> list[Any] | None:
    """Resolve the requested logits processors, enforcing the allow-pattern.

    Args:
        processors: Entries to resolve; each is a qualified name or a
            ``LogitsProcessorConstructor`` (qualified name plus arguments).
        pattern: Regular expression each qualified name must match from its
            start (``re.match``). A falsy value means no processor may be
            requested.

    Returns:
        ``None`` when ``processors`` is ``None`` or empty. Otherwise a list
        with one item per entry, in order: the resolved object itself for a
        plain-string entry, or the result of calling the resolved object with
        the entry's ``args`` / ``kwargs`` for a constructor entry.

    Raises:
        ValueError: If processors are requested but ``pattern`` is falsy, if
            a qualified name does not match ``pattern``, or if a qualified
            name cannot be resolved.
    """
    if not processors:
        return None

    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information."
        )

    logits_processors = []
    for processor in processors:
        qualname = processor if isinstance(processor, str) else processor.qualname
        # Check the allow-pattern before resolving, so that names that are
        # not allowed are never looked up.
        if not re.match(pattern, qualname):
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information."
            )
        try:
            logits_processor = resolve_obj_by_qualname(qualname)
        except Exception as e:
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {e}"
            ) from e
        if isinstance(processor, LogitsProcessorConstructor):
            # Constructor entries are instantiated; missing args/kwargs are
            # treated as empty.
            logits_processor = logits_processor(
                *(processor.args or []), **(processor.kwargs or {})
            )
        logits_processors.append(logits_processor)
    return logits_processors


223
224
225
226
227
228
class FunctionCall(OpenAIBaseModel):
    """A function invocation: the function name and its arguments as a string."""

    name: str
    # Kept as a string; this model does not parse or validate its contents.
    arguments: str


class ToolCall(OpenAIBaseModel):
    """A complete tool call: id, fixed ``"function"`` type, and the call itself."""

    # Generated by make_tool_call_id() when the caller supplies no id.
    id: str = Field(default_factory=make_tool_call_id)
    type: Literal["function"] = "function"
    function: FunctionCall


234
class DeltaFunctionCall(BaseModel):
    """Partial function call used by ``DeltaToolCall``; both fields optional.

    Derives from plain ``BaseModel`` rather than ``OpenAIBaseModel``.
    """

    name: str | None = None
    arguments: str | None = None
237
238
239
240


# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
    """Tool call delta: every field is optional except the required ``index``."""

    id: str | None = None
    type: Literal["function"] | None = None
    # Required; has no default.
    index: int
    function: DeltaFunctionCall | None = None
245
246
247
248
249
250
251


class ExtractedToolCallInformation(BaseModel):
    """Result of extracting tool calls: a flag, the calls, and leftover content.

    Derives from plain ``BaseModel`` rather than ``OpenAIBaseModel``.
    """

    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: list[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: str | None = None
257
258


259
class DeltaMessage(OpenAIBaseModel):
    """Incremental message fragment; every field has a default.

    After validation ``reasoning_content`` always equals ``reasoning``.
    """

    role: str | None = None
    content: str | None = None
    reasoning: str | None = None
    reasoning_content: str | None = None
    """Deprecated: use `reasoning` instead."""
    # A new empty list per instance unless tool calls are supplied.
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)

    @model_validator(mode="after")
    def handle_deprecated_reasoning_content(self):
        """Copy reasoning to reasoning_content for backward compatibility."""
        # Unconditional assignment: a value passed for reasoning_content is
        # replaced by reasoning, including when reasoning is None.
        self.reasoning_content = self.reasoning
        return self

273

274
275
276
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
    """Request for the tokens-in / tokens-out generation path.

    ``token_ids`` and ``sampling_params`` are required; every other field has
    a default. Derives from plain ``BaseModel`` rather than
    ``OpenAIBaseModel``.
    """

    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    token_ids: list[int]
    """The token ids to generate text from."""

    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None
    """The processed MM inputs for the model."""

    sampling_params: SamplingParams
    """The sampling parameters for the model."""

    model: str | None = None

    stream: bool | None = False
    stream_options: StreamOptions | None = None
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )