# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import BaseModel, Field, model_validator

from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

Zhuohan Li's avatar
Zhuohan Li committed
13
14
15
16
17
18

class ErrorResponse(BaseModel):
    """Schema of an error response.

    ``object`` defaults to ``"error"``. ``message``, ``type`` and ``code``
    are required; ``param`` is optional and defaults to ``None``.
    """
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40


class ModelPermission(BaseModel):
    """Permission record carried in ``ModelCard.permission``.

    ``id`` is generated per instance as ``"modelperm-<random uuid>"`` and
    ``created`` is the integer Unix time at construction. All remaining
    fields are static flags with the defaults shown below.
    """
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    # Boolean flag: the annotation must match the ``False`` default so that
    # explicit ``True``/``False`` inputs validate and the schema is correct.
    is_blocking: bool = False


class ModelCard(BaseModel):
    """Description of a single model.

    ``id`` is required. ``created`` is the integer Unix time at
    construction, ``owned_by`` defaults to ``"vllm"`` and ``permission``
    defaults to a fresh empty list of :class:`ModelPermission`.
    """
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
    """List envelope holding :class:`ModelCard` entries.

    ``object`` defaults to ``"list"`` and ``data`` to a fresh empty list.
    """
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(BaseModel):
    """Token counters attached to responses; every counter defaults to 0.

    ``completion_tokens`` is the only counter that also accepts ``None``.
    """
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


58
59
60
61
62
class ResponseFormat(BaseModel):
    """Requested output format of a completion.

    ``type`` is restricted to ``"text"`` or ``"json_object"`` and defaults
    to ``"text"``, so the model can still be built without arguments.
    """
    # type must be "json_object" or "text"
    # The constraint belongs in the annotation. Writing it as the default
    # value (``type: str = Literal[...]``) stores the typing object itself
    # as the default and validates nothing.
    type: Literal["text", "json_object"] = "text"


Zhuohan Li's avatar
Zhuohan Li committed
63
64
class ChatCompletionRequest(BaseModel):
    """Request schema for a chat completion.

    Besides declaring the request fields, the class converts itself into
    a ``SamplingParams`` object (:meth:`to_sampling_params`) and rejects
    inputs that select more than one of ``guided_json``, ``guided_regex``
    and ``guided_choice``.
    """
    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    # Additional parameters supported by vLLM
    best_of: Optional[int] = None
    top_k: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    add_generation_prompt: Optional[bool] = True
    echo: Optional[bool] = False
    repetition_penalty: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    include_stop_str_in_output: Optional[bool] = False
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_grammar: Optional[str] = None
    response_format: Optional[ResponseFormat] = None

    def to_sampling_params(self) -> SamplingParams:
        """Build the ``SamplingParams`` described by this request.

        Returns:
            A ``SamplingParams`` populated from the request fields. When
            ``logit_bias`` is non-empty, a single logits processor that
            adds the clamped biases in place is attached.

        Raises:
            ValueError: If ``logprobs`` is truthy while ``top_logprobs``
                is unset or 0, or if a ``logit_bias`` key cannot be
                converted to an integer token id.
        """
        # ``not self.top_logprobs`` is also true for 0, so a value of 0 is
        # treated the same as an unset field here.
        if self.logprobs and not self.top_logprobs:
            raise ValueError("Top logprobs must be set when logprobs is.")

        logits_processors = None
        if self.logit_bias:
            # Convert the keys and clamp the values once, instead of on
            # every sampling step, so a malformed key fails at request
            # time. Pairs (not a dict) keep one entry per original key.
            try:
                clamped_biases = [
                    # Clamp the bias between -100 and 100 per OpenAI API spec
                    (int(token_id), min(100, max(-100, bias)))
                    for token_id, bias in self.logit_bias.items()
                ]
            except ValueError as exc:
                raise ValueError(
                    "logit_bias keys must be integer token ids (or "
                    "strings representing integers).") from exc

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                # Mutates ``logits`` in place and returns the same tensor.
                for token_id, bias in clamped_biases:
                    logits[token_id] += bias
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=self.max_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.top_logprobs if self.echo else None,
            best_of=self.best_of,
            top_k=self.top_k,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data: Any) -> Any:
        """Reject input that selects more than one guided-decoding option.

        Args:
            data: Raw, not yet validated input of the model.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: If more than one of ``guided_json``,
                ``guided_regex`` and ``guided_choice`` is present and
                not ``None``.
        """
        # A "before" validator receives the raw input, which need not be
        # a dict; leave anything else to pydantic's own validation.
        if not isinstance(data, dict):
            return data
        # NOTE(review): ``guided_grammar`` is not part of this check;
        # confirm whether that is intentional.
        guide_count = sum(
            data.get(option) is not None
            for option in ("guided_json", "guided_regex", "guided_choice"))
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

Zhuohan Li's avatar
Zhuohan Li committed
158
159
160

class CompletionRequest(BaseModel):
    """Request schema for a text completion.

    Besides declaring the request fields, the class converts itself into
    a ``SamplingParams`` object (:meth:`to_sampling_params`) and rejects
    inputs that select more than one of ``guided_json``, ``guided_regex``
    and ``guided_choice``.
    """
    model: str
    # a string, array of strings, array of tokens, or array of token arrays
    prompt: Union[List[int], List[List[int]], str, List[str]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    seed: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    # Additional parameters supported by vLLM
    top_k: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    repetition_penalty: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    include_stop_str_in_output: Optional[bool] = False
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_grammar: Optional[str] = None
    response_format: Optional[ResponseFormat] = None

    def to_sampling_params(self) -> SamplingParams:
        """Build the ``SamplingParams`` described by this request.

        When ``echo`` is set and ``max_tokens`` is 0, ``max_tokens=1`` is
        passed to ``SamplingParams`` instead of 0.

        Returns:
            A ``SamplingParams`` populated from the request fields. When
            ``logit_bias`` is non-empty, a single logits processor that
            adds the clamped biases in place is attached.

        Raises:
            ValueError: If a ``logit_bias`` key cannot be converted to an
                integer token id.
        """
        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = None
        if self.logit_bias:
            # Convert the keys and clamp the values once, instead of on
            # every sampling step, so a malformed key fails at request
            # time. Pairs (not a dict) keep one entry per original key.
            try:
                clamped_biases = [
                    # Clamp the bias between -100 and 100 per OpenAI API spec
                    (int(token_id), min(100, max(-100, bias)))
                    for token_id, bias in self.logit_bias.items()
                ]
            except ValueError as exc:
                raise ValueError(
                    "logit_bias keys must be integer token ids (or "
                    "strings representing integers).") from exc

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                # Mutates ``logits`` in place and returns the same tensor.
                for token_id, bias in clamped_biases:
                    logits[token_id] += bias
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=self.max_tokens if not echo_without_generation else 1,
            logprobs=self.logprobs,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            prompt_logprobs=self.logprobs if self.echo else None,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=(self.spaces_between_special_tokens),
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data: Any) -> Any:
        """Reject input that selects more than one guided-decoding option.

        Args:
            data: Raw, not yet validated input of the model.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: If more than one of ``guided_json``,
                ``guided_regex`` and ``guided_choice`` is present and
                not ``None``.
        """
        # A "before" validator receives the raw input, which need not be
        # a dict; leave anything else to pydantic's own validation.
        if not isinstance(data, dict):
            return data
        # NOTE(review): ``guided_grammar`` is not part of this check;
        # confirm whether that is intentional.
        guide_count = sum(
            data.get(option) is not None
            for option in ("guided_json", "guided_regex", "guided_choice"))
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

Zhuohan Li's avatar
Zhuohan Li committed
253
254
255
256
257

class LogProbs(BaseModel):
    """Log-probability data attached to a response choice.

    ``text_offset``, ``token_logprobs`` and ``tokens`` each default to a
    fresh empty list; ``top_logprobs`` defaults to ``None``.
    """
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    # NOTE(review): keys are declared as ``int`` while ``tokens`` holds
    # strings -- confirm which key type the producers of this field emit.
    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
Zhuohan Li's avatar
Zhuohan Li committed
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289


class CompletionResponseChoice(BaseModel):
    """One entry of ``CompletionResponse.choices``.

    ``finish_reason`` is ``None`` or one of ``"stop"`` / ``"length"``.
    """
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionResponse(BaseModel):
    """Response schema for a text completion.

    ``id`` is generated per instance as ``"cmpl-<random uuid>"`` and
    ``created`` is the integer Unix time at construction. ``model``,
    ``choices`` and ``usage`` are required.
    """
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    """One entry of ``CompletionStreamResponse.choices``.

    ``finish_reason`` is ``None`` or one of ``"stop"`` / ``"length"``.
    """
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionStreamResponse(BaseModel):
    """Streaming response schema for a text completion.

    ``id`` is generated per instance as ``"cmpl-<random uuid>"`` and
    ``created`` is the integer Unix time at construction. Unlike
    :class:`CompletionResponse`, ``usage`` is optional and defaults to
    ``None``.
    """
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
291
292
293
294
295
296
297
298
299
300


class ChatMessage(BaseModel):
    """A chat message: a required ``role`` and required ``content``."""
    role: str
    content: str


class ChatCompletionResponseChoice(BaseModel):
    """One entry of ``ChatCompletionResponse.choices``.

    Wraps a complete :class:`ChatMessage`. ``finish_reason`` is ``None``
    or one of ``"stop"`` / ``"length"``.
    """
    index: int
    message: ChatMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    """Response schema for a chat completion.

    ``id`` is generated per instance as ``"chatcmpl-<random uuid>"`` and
    ``created`` is the integer Unix time at construction. ``model``,
    ``choices`` and ``usage`` are required.
    """
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaMessage(BaseModel):
    """Message fragment used by streamed chat choices.

    Both ``role`` and ``content`` are optional and default to ``None``.
    """
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    """One entry of ``ChatCompletionStreamResponse.choices``.

    Carries a :class:`DeltaMessage` rather than a full message.
    ``finish_reason`` is ``None`` or one of ``"stop"`` / ``"length"``.
    """
    index: int
    delta: DeltaMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
    """Streaming response schema for a chat completion.

    ``id`` is generated per instance as ``"chatcmpl-<random uuid>"`` and
    ``created`` is the integer Unix time at construction. ``object``
    defaults to ``"chat.completion.chunk"``; ``usage`` is optional and
    defaults to ``None``.
    """
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)