# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
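"""Pydantic request/response schemas for vLLM's OpenAI-compatible API
server: completions, chat completions, model listing, and the
vLLM-specific sampling extensions layered on top of them."""
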
import time
from typing import Dict, List, Literal, Optional, Union

import torch
from pydantic import BaseModel, Field, model_validator

from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(BaseModel):
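    """Permission entry embedded in a model card, mirroring the
    permission object of the OpenAI models endpoint."""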
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(BaseModel):
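    """Token usage accounting attached to responses."""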
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
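    """Request body for `/v1/chat/completions`.

    The first block of fields mirrors the OpenAI API; everything below
    the "Additional parameters" comment is a vLLM extension."""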
    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    # Additional parameters supported by vLLM
    best_of: Optional[int] = None
    top_k: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    add_generation_prompt: Optional[bool] = True
    echo: Optional[bool] = False
    repetition_penalty: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    include_stop_str_in_output: Optional[bool] = False
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None

    def to_sampling_params(self) -> SamplingParams:
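        """Translate this request into vLLM `SamplingParams`."""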
        if self.logprobs and not self.top_logprobs:
            raise ValueError(
                "`top_logprobs` must be set when `logprobs` is enabled.")

        # Translate the OpenAI logit_bias map (token id -> bias) into a
        # vLLM logits processor.
        logits_processors = None
        if self.logit_bias:

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                for token_id, bias in self.logit_bias.items():
                    # Clamp the bias between -100 and 100 per OpenAI API spec
                    bias = min(100, max(-100, bias))
                    logits[int(token_id)] += bias
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=self.max_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.top_logprobs if self.echo else None,
            best_of=self.best_of,
            top_k=self.top_k,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
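        """Reject requests that combine more than one guided-decoding mode."""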
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data


class CompletionRequest(BaseModel):
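    """Request body for `/v1/completions`, following the same
    OpenAI-fields-first, vLLM-extensions-second layout as the chat
    completion request."""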
    model: str
    # a string, array of strings, array of tokens, or array of token arrays
    prompt: Union[List[int], List[List[int]], str, List[str]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    seed: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    # Additional parameters supported by vLLM
    top_k: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    repetition_penalty: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    include_stop_str_in_output: Optional[bool] = False
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None

    def to_sampling_params(self) -> SamplingParams:
        """Translate this request into vLLM `SamplingParams`."""
        # `echo` with `max_tokens == 0` asks for the prompt (and its
        # logprobs) back without generating new tokens; `max_tokens` is
        # clamped to 1 below since vLLM must generate at least one token.
        echo_without_generation = self.echo and self.max_tokens == 0

        # Translate the OpenAI logit_bias map (token id -> bias) into a
        # vLLM logits processor.
        logits_processors = None
        if self.logit_bias:

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                for token_id, bias in self.logit_bias.items():
                    # Clamp the bias between -100 and 100 per OpenAI API spec
                    bias = min(100, max(-100, bias))
                    logits[int(token_id)] += bias
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=self.max_tokens if not echo_without_generation else 1,
            logprobs=self.logprobs,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            prompt_logprobs=self.logprobs if self.echo else None,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
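        """Reject requests that combine more than one guided-decoding mode."""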
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data


class LogProbs(BaseModel):
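    """Per-token log probabilities in the OpenAI completions format."""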
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)