protocol.py 16.7 KB
Newer Older
1
2
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
Zhuohan Li's avatar
Zhuohan Li committed
3
4
5
import time
from typing import Dict, List, Literal, Optional, Union

6
import torch
7
8
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
9
from typing_extensions import Annotated
Zhuohan Li's avatar
Zhuohan Li committed
10

11
from vllm.sampling_params import SamplingParams
12
from vllm.utils import random_uuid
13

Zhuohan Li's avatar
Zhuohan Li committed
14

15
16
17
18
19
20
class OpenAIBaseModel(BaseModel):
    # Common base class for every request/response model in this module.
    # OpenAI API does not allow extra fields, so unknown keys are rejected
    # (pydantic `extra="forbid"`) instead of being silently ignored.
    model_config = ConfigDict(extra="forbid")


class ErrorResponse(OpenAIBaseModel):
    # Error payload. `message`, `type` and `code` must be supplied by the
    # caller; `object` defaults to the literal "error" and `param` is
    # optional.
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
26
27


28
class ModelPermission(OpenAIBaseModel):
    # Permission record attached to a ModelCard. Every field has a default,
    # so `ModelPermission()` is a valid instance.
    # `id` is generated per instance as "modelperm-<random_uuid()>".
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    # Creation time in whole seconds, taken from time.time() at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
41
42


43
class ModelCard(OpenAIBaseModel):
    # Description of a single model. Only `id` is required.
    id: str
    object: str = "model"
    # Creation time in whole seconds, taken from time.time() at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
    # default_factory gives each card its own (initially empty) list.
    permission: List[ModelPermission] = Field(default_factory=list)


53
class ModelList(OpenAIBaseModel):
    # Container of ModelCard entries; `data` defaults to a fresh empty list.
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


58
class UsageInfo(OpenAIBaseModel):
    # Token counters reported with a response. All default to 0;
    # `completion_tokens` may additionally be None.
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


64
class ResponseFormat(OpenAIBaseModel):
    # Requested output format of a completion.
    # type must be "json_object" or "text"; any other value fails validation
    # because of the Literal annotation.
    type: Literal["text", "json_object"]
67
68


69
class ChatCompletionRequest(OpenAIBaseModel):
    # Request body for chat completions: the OpenAI-documented fields first,
    # followed by additional sampling parameters and extra parameters
    # (delimited by the "# doc:" markers below).
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when the stop or stop_token_ids is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))

    # doc: end-chat-completion-extra-params

    def to_sampling_params(self) -> SamplingParams:
        """Build the ``SamplingParams`` that correspond to this request.

        Returns:
            A ``SamplingParams`` populated from the request fields. When
            ``logit_bias`` is non-empty, a logits processor applying the
            (clamped) biases is attached.

        Raises:
            ValueError: If ``logprobs`` is true while ``top_logprobs`` is
                unset, or if a ``logit_bias`` key cannot be converted to an
                integer token id.
        """
        # Compare against None instead of relying on truthiness: 0 is an
        # explicitly provided value and must not be reported as "not set".
        if self.logprobs and self.top_logprobs is None:
            raise ValueError("Top logprobs must be set when logprobs is.")

        logits_processors = None
        if self.logit_bias:
            # Convert the string keys and clamp the biases once, up front.
            # A malformed key is then reported here as a ValueError instead
            # of surfacing later from inside the logits processor.
            try:
                # Clamp the bias between -100 and 100 per OpenAI API spec
                clamped_biases = [(int(token_id), min(100, max(-100, bias)))
                                  for token_id, bias in
                                  self.logit_bias.items()]
            except ValueError as exc:
                raise ValueError(
                    "Found token_id in logit_bias that is not an integer "
                    "or a string representing an integer.") from exc

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                # `token_ids` is part of the processor signature but unused.
                # NOTE(review): token ids are not range-checked against the
                # size of `logits`; an out-of-range id fails on indexing.
                for token_index, clamped in clamped_biases:
                    logits[token_index] += clamped
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=self.max_tokens,
            min_tokens=self.min_tokens,
            # Output logprobs are only requested when `logprobs` is true;
            # prompt logprobs only when `echo` is true.
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.top_logprobs if self.echo else None,
            best_of=self.best_of,
            top_k=self.top_k,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Reject input that sets more than one guided-decoding option.

        Runs on the raw input before field validation and returns it
        unchanged when the check passes.

        Raises:
            ValueError: If more than one of ``guided_json``,
                ``guided_regex`` and ``guided_choice`` is present and not
                None.
        """
        # NOTE(review): `guided_grammar` is not counted here, so it can be
        # combined with another guided option without an error - confirm
        # whether that is intentional.
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

Zhuohan Li's avatar
Zhuohan Li committed
219

220
class CompletionRequest(OpenAIBaseModel):
    # Request body for text completions: the OpenAI-documented fields first,
    # followed by additional sampling parameters and extra parameters
    # (delimited by the "# doc:" markers below).
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when the stop or stop_token_ids is set."),
    )
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'} or {'type': 'text' } is "
         "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))

    # doc: end-completion-extra-params

    def to_sampling_params(self) -> SamplingParams:
        """Build the ``SamplingParams`` that correspond to this request.

        Returns:
            A ``SamplingParams`` populated from the request fields. When
            ``logit_bias`` is non-empty, a logits processor applying the
            (clamped) biases is attached.

        Raises:
            ValueError: If a ``logit_bias`` key cannot be converted to an
                integer token id.
        """
        # With echo and max_tokens == 0 the request asks for no new tokens;
        # max_tokens=1 is passed to SamplingParams in that case (see below).
        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = None
        if self.logit_bias:
            # Convert the string keys and clamp the biases once, up front.
            # A malformed key is then reported here as a ValueError instead
            # of surfacing later from inside the logits processor.
            try:
                # Clamp the bias between -100 and 100 per OpenAI API spec
                clamped_biases = [(int(token_id), min(100, max(-100, bias)))
                                  for token_id, bias in
                                  self.logit_bias.items()]
            except ValueError as exc:
                raise ValueError(
                    "Found token_id in logit_bias that is not an integer "
                    "or a string representing an integer.") from exc

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                # `token_ids` is part of the processor signature but unused.
                # NOTE(review): token ids are not range-checked against the
                # size of `logits`; an out-of-range id fails on indexing.
                for token_index, clamped in clamped_biases:
                    logits[token_index] += clamped
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=self.max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            # Prompt logprobs are only requested when the prompt is echoed.
            prompt_logprobs=self.logprobs if self.echo else None,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=(self.spaces_between_special_tokens),
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Reject input that sets more than one guided-decoding option.

        Runs on the raw input before field validation and returns it
        unchanged when the check passes.

        Raises:
            ValueError: If more than one of ``guided_json``,
                ``guided_regex`` and ``guided_choice`` is present and not
                None.
        """
        # NOTE(review): `guided_grammar` is not counted here, so it can be
        # combined with another guided option without an error - confirm
        # whether that is intentional.
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

Zhuohan Li's avatar
Zhuohan Li committed
365

366
class LogProbs(OpenAIBaseModel):
    # Log-probability payload attached to a response choice. The three list
    # fields default to fresh empty lists; `top_logprobs` defaults to None.
    text_offset: List[int] = Field(default_factory=list)
    # Individual entries may be None.
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    # When present, each entry is either None or a dict of str -> float.
    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
Zhuohan Li's avatar
Zhuohan Li committed
371
372


373
class CompletionResponseChoice(OpenAIBaseModel):
    # One choice of a (non-streaming) text completion response.
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
Zhuohan Li's avatar
Zhuohan Li committed
385
386


387
class CompletionResponse(OpenAIBaseModel):
    # Full (non-streaming) text completion response. `model`, `choices` and
    # `usage` are required.
    # `id` is generated per instance as "cmpl-<random_uuid()>".
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Creation time in whole seconds, taken from time.time() at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


396
class CompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed text completion chunk.
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
Zhuohan Li's avatar
Zhuohan Li committed
408
409


410
class CompletionStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed text completion. Unlike CompletionResponse,
    # `usage` is optional and defaults to None.
    # `id` is generated per instance as "cmpl-<random_uuid()>".
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Creation time in whole seconds, taken from time.time() at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
417
418


419
class ChatMessage(OpenAIBaseModel):
    # A complete chat message; both fields are required.
    role: str
    content: str


424
class ChatCompletionResponseChoice(OpenAIBaseModel):
    # One choice of a (non-streaming) chat completion response, carrying the
    # full `message`.
    index: int
    message: ChatMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None
430
431


432
class ChatCompletionResponse(OpenAIBaseModel):
    # Full (non-streaming) chat completion response. `model`, `choices` and
    # `usage` are required.
    # `id` is generated per instance as "chatcmpl-<random_uuid()>".
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion"
    # Creation time in whole seconds, taken from time.time() at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


441
class DeltaMessage(OpenAIBaseModel):
    # Incremental chat message used in streamed chunks; unlike ChatMessage,
    # both fields are optional and default to None.
    role: Optional[str] = None
    content: Optional[str] = None


446
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed chat completion chunk, carrying a
    # DeltaMessage in `delta`.
    index: int
    delta: DeltaMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None
452
453


454
class ChatCompletionStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed chat completion. Unlike ChatCompletionResponse,
    # `usage` is optional and defaults to None.
    # `id` is generated per instance as "chatcmpl-<random_uuid()>".
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion.chunk"
    # Creation time in whole seconds, taken from time.time() at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)