protocol.py 40.7 KB
Newer Older
1
2
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
Zhuohan Li's avatar
Zhuohan Li committed
3
import time
4
from argparse import Namespace
5
from typing import Any, Dict, List, Literal, Optional, Union
Zhuohan Li's avatar
Zhuohan Li committed
6

7
import torch
8
from pydantic import BaseModel, ConfigDict, Field, model_validator
9
from typing_extensions import Annotated
Zhuohan Li's avatar
Zhuohan Li committed
10

11
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
12
from vllm.pooling_params import PoolingParams
13
14
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
15
from vllm.sequence import Logprob
16
from vllm.utils import random_uuid
17

18
19
20
# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
21
_LONG_INFO: Union["torch.iinfo", Namespace]
22
23
24
25
26
27
28
29
30
31
32
33
34
35

try:
    from sphinx.ext.autodoc.mock import _MockModule

    if isinstance(torch, _MockModule):
        _LONG_INFO = _MOCK_LONG_INFO
    else:
        _LONG_INFO = torch.iinfo(torch.long)
except ModuleNotFoundError:
    _LONG_INFO = torch.iinfo(torch.long)

assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max

Zhuohan Li's avatar
Zhuohan Li committed
36

37
38
39
40
41
42
class OpenAIBaseModel(BaseModel):
    # Base class for the OpenAI-protocol models declared in this module.
    # (Documented with comments rather than a docstring: pydantic copies a
    # model docstring into the generated JSON schema description.)
    # OpenAI API does not allow extra fields
    model_config = ConfigDict(extra="forbid")


class ErrorResponse(OpenAIBaseModel):
    # Error payload model. `object` defaults to "error"; `message`, `type`
    # and `code` are required, `param` is optional.
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int
Zhuohan Li's avatar
Zhuohan Li committed
48
49


50
class ModelPermission(OpenAIBaseModel):
    # Permission entry attached to a ModelCard.
    # `id` and `created` are generated per instance via default factories
    # (a "modelperm-" prefixed random uuid and the current Unix time in
    # whole seconds); every other field has a static default.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False
Zhuohan Li's avatar
Zhuohan Li committed
63
64


65
class ModelCard(OpenAIBaseModel):
    # Description of a single model. Only `id` is required; `created`
    # defaults to the current Unix time in whole seconds and `permission`
    # to a fresh empty list per instance.
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


76
class ModelList(OpenAIBaseModel):
    # List wrapper around ModelCard entries; `data` defaults to a fresh
    # empty list per instance.
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


81
82
83
84
class PromptTokenUsageInfo(OpenAIBaseModel):
    # Prompt-side usage detail, used as the type of
    # `UsageInfo.prompt_tokens_details`.
    # NOTE(review): presumably the number of prompt tokens served from a
    # cache -- confirm against the code that fills this in.
    cached_tokens: Optional[int] = None


85
class UsageInfo(OpenAIBaseModel):
    # Token accounting for a request. All counters default to 0;
    # `prompt_tokens_details` is optional and omitted (None) by default.
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
    prompt_tokens_details: Optional[PromptTokenUsageInfo] = None
Zhuohan Li's avatar
Zhuohan Li committed
90
91


92
93
94
95
96
class RequestResponseMetadata(BaseModel):
    # Per-request metadata holder. Derives from pydantic's BaseModel directly
    # (not OpenAIBaseModel), so the `extra="forbid"` config does not apply.
    request_id: str
    # None until a final UsageInfo is assigned.
    final_usage_info: Optional[UsageInfo] = None


97
98
99
100
101
102
103
104
105
class JsonSchemaResponseFormat(OpenAIBaseModel):
    # Named JSON schema carried by a ResponseFormat of type "json_schema".
    name: str
    description: Optional[str] = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
    strict: Optional[bool] = None


106
class ResponseFormat(OpenAIBaseModel):
    # Requested output format of a completion.
    # type must be "json_schema", "json_object" or "text"
    type: Literal["text", "json_object", "json_schema"]
    # Schema payload; read by ChatCompletionRequest.to_sampling_params when
    # `type` is "json_schema".
    json_schema: Optional[JsonSchemaResponseFormat] = None
110
111


112
class StreamOptions(OpenAIBaseModel):
    # Options that the request validators only accept together with
    # `stream=True`.
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = False
115
116


117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
class FunctionDefinition(OpenAIBaseModel):
    # Function exposed to the model as a tool.
    name: str
    description: Optional[str] = None
    # Returned as the guided-JSON schema by
    # ChatCompletionRequest._get_guided_json_from_tool when this function is
    # the named tool choice.
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    # One entry of a request's `tools` list; "function" is the only
    # accepted `type`.
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    # Reference to a function by name, used inside a named tool choice.
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    # `tool_choice` value that selects one specific function by name;
    # "function" is the only accepted `type`.
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


137
class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    # TODO(#9845): remove max_tokens when field is removed from OpenAI API
    max_tokens: Optional[int] = Field(
        default=None,
        deprecated=
        'max_tokens is deprecated in favor of the max_completion_tokens field')
    max_completion_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"], Literal["auto"],
                                ChatCompletionNamedToolChoiceParam]] = "none"

    # NOTE this will be ignored by VLLM -- the model determines the behavior
    parallel_tool_calls: Optional[bool] = False
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: bool = False
    top_k: int = -1
    min_p: float = 0.0
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    prompt_logprobs: Optional[int] = None
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: bool = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: bool = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    continue_final_message: bool = Field(
        default=False,
        description=
        ("If this is set, the chat will be formatted so that the final "
         "message in the chat is open-ended, without any EOS tokens. The "
         "model will continue this message rather than starting a new one. "
         "This allows you to \"prefill\" part of the model's response for it. "
         "Cannot be used at the same time as `add_generation_prompt`."),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=
        ("A list of dicts representing documents that will be accessible to "
         "the model if it is performing RAG (retrieval-augmented generation)."
         " If the template does not support RAG, this argument will have no "
         "effect. We recommend that each document should be a dict containing "
         "\"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."))

    # doc: end-chat-completion-extra-params

    def to_beam_search_params(self,
                              default_max_tokens: int) -> BeamSearchParams:
        """Build the `BeamSearchParams` for this request.

        Args:
            default_max_tokens: Token budget used when neither
                `max_completion_tokens` nor `max_tokens` yields a value.

        Returns:
            Beam search parameters with `n` as the beam width (1 when `n` is
            None) and `temperature` falling back to 0.0 when it is None.
        """
        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
        max_tokens = self.max_completion_tokens or self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        n = self.n if self.n is not None else 1
        temperature = self.temperature if self.temperature is not None else 0.0

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output)

    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
        """Build the `SamplingParams` (including guided decoding) for this
        request.

        Side effect: when `response_format.type` is "json_schema", the schema
        is stored into `self.guided_json`, and `self.guided_decoding_backend`
        is set to "lm-format-enforcer" if it was not specified.

        Args:
            default_max_tokens: Token budget used when neither
                `max_completion_tokens` nor `max_tokens` yields a value.

        Raises:
            ValueError: If `response_format.type` is "json_schema" but no
                `json_schema` payload was supplied, or (from
                `_get_guided_json_from_tool`) if the named tool choice is not
                present in `tools`.
        """
        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
        max_tokens = self.max_completion_tokens or self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # With `echo`, prompt logprobs default to the `top_logprobs` setting.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        guided_json_object = None
        if self.response_format is not None:
            if self.response_format.type == "json_object":
                guided_json_object = True
            elif self.response_format.type == "json_schema":
                json_schema = self.response_format.json_schema
                # This is client input, so report it as a ValueError instead
                # of an `assert`, which is stripped when running with -O.
                if json_schema is None:
                    raise ValueError(
                        "`response_format.json_schema` must be set when "
                        "`response_format.type` is \"json_schema\".")
                self.guided_json = json_schema.json_schema
                if self.guided_decoding_backend is None:
                    self.guided_decoding_backend = "lm-format-enforcer"

        # A schema derived from a named tool choice takes precedence over
        # `guided_json`.
        guided_decoding = GuidedDecodingParams.from_optional(
            json=self._get_guided_json_from_tool() or self.guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern)

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA if self.stream \
                else RequestOutputKind.FINAL_ONLY,
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias)

    def _get_guided_json_from_tool(
            self) -> Optional[Union[str, dict, BaseModel]]:
        """Return the `parameters` schema of the named tool choice, if any.

        Returns:
            None when `tool_choice` is "none", when no tools were given, or
            when `tool_choice` is not a named tool; otherwise the selected
            function's `parameters`.

        Raises:
            ValueError: If the named tool is not present in `tools`.
        """
        # user has chosen to not use any tool
        if self.tool_choice == "none" or self.tools is None:
            return None

        # user has chosen to use a named tool
        if isinstance(self.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_name = self.tool_choice.function.name
            tools = {tool.function.name: tool.function for tool in self.tools}
            if tool_name not in tools:
                raise ValueError(
                    f"Tool '{tool_name}' has not been passed in `tools`.")
            tool = tools[tool_name]
            return tool.parameters

        return None

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` unless `stream` is truthy."""
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` / `top_logprobs` on the raw input.

        Negative values are rejected, positive `prompt_logprobs` cannot be
        combined with streaming, and a non-None `top_logprobs` requires
        `logprobs` to be truthy.
        """
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")

            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        if (top_logprobs := data.get("top_logprobs")) is not None:
            if top_logprobs < 0:
                raise ValueError("`top_logprobs` must be a positive value.")

            if not data.get("logprobs"):
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Allow at most one of `guided_json` / `guided_regex` /
        `guided_choice`, and forbid combining any of them with a named tool
        choice.
        """
        if isinstance(data, ValueError):
            raise data

        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both.
        # This must test `>= 1`: a count above 1 has already raised, so a
        # `> 1` comparison here could never be true.
        if guide_count >= 1 and data.get("tool_choice",
                                         "none") not in ("none", "auto"):
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        """Normalize and validate `tool_choice` against `tools`.

        Mutates `data`: defaults `tool_choice` to "auto" when only `tools`
        is given, and drops `tools` when `tool_choice` is "none".
        """

        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        # if "tool_choice" is "none" -- ignore tools if present
        if "tool_choice" in data and data["tool_choice"] == "none":
            # ensure that no tools are present
            data.pop("tools", None)
            return data

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data:

            # ensure that if "tool choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto"
            if data["tool_choice"] != "auto" and not isinstance(
                    data["tool_choice"], dict):
                raise ValueError(
                    "`tool_choice` must either be a named tool, \"auto\", "
                    "or \"none\".")

            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            if isinstance(data["tool_choice"], dict):
                # `.get` so that a missing key produces the ValueError below
                # instead of an unhandled KeyError.
                specified_function = data["tool_choice"].get("function")
                if not (isinstance(specified_function, dict)
                        and specified_function):
                    raise ValueError(
                        "Incorrectly formatted `tool_choice`. Should be like "
                        "`{\"type\": \"function\","
                        " \"function\": {\"name\": \"my_function\"}}`")
                specified_function_name = specified_function.get("name")
                if not specified_function_name:
                    raise ValueError(
                        "Incorrectly formatted `tool_choice`. Should be like "
                        "`{\"type\": \"function\", "
                        "\"function\": {\"name\": \"my_function\"}}`")
                valid_tool = any(
                    tool["function"]["name"] == specified_function_name
                    for tool in data["tools"])
                if not valid_tool:
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match any"
                        " of the specified `tools`")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests that enable both `continue_final_message` and
        `add_generation_prompt`.
        """
        if data.get("continue_final_message") and data.get(
                "add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data

Zhuohan Li's avatar
Zhuohan Li committed
498

499
class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: bool = False
    top_k: int = -1
    min_p: float = 0.0
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    prompt_logprobs: Optional[int] = None
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
         "{'type': 'text' } is supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description="If specified, the output will follow the JSON schema.",
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-completion-extra-params

    def to_beam_search_params(self,
                              default_max_tokens: int) -> BeamSearchParams:
        """Build the `BeamSearchParams` for this request.

        Args:
            default_max_tokens: Token budget used when `max_tokens` is None.

        Returns:
            Beam search parameters with `n` as the beam width and
            `temperature` falling back to 0.0 when it is None.
        """
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        n = self.n if self.n is not None else 1
        temperature = self.temperature if self.temperature is not None else 0.0

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output)

    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
        """Build the `SamplingParams` (including guided decoding) for this
        request.

        Args:
            default_max_tokens: Token budget used when `max_tokens` is None.
        """
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # With `echo`, prompt logprobs default to the `logprobs` setting.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        # `echo` with `max_tokens == 0`: a budget of 1 token is passed to
        # SamplingParams below instead of 0.
        echo_without_generation = self.echo and self.max_tokens == 0

        # NOTE(review): only the "json_object" response format is translated
        # here; a "json_schema" response format is not turned into guided
        # JSON in this method even though the field description lists it --
        # confirm whether that is intended.
        guided_json_object = None
        if (self.response_format is not None
                and self.response_format.type == "json_object"):
            guided_json_object = True

        guided_decoding = GuidedDecodingParams.from_optional(
            json=self.guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern)

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA if self.stream \
                else RequestOutputKind.FINAL_ONLY,
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids)

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Allow at most one of `guided_json` / `guided_regex` /
        `guided_choice` in the raw input.
        """
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` / `logprobs` on the raw input.

        Negative values are rejected, and positive `prompt_logprobs` cannot
        be combined with streaming.
        """
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")

            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise ValueError("`logprobs` must be a positive value.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` unless `stream` is truthy."""
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data

Zhuohan Li's avatar
Zhuohan Li committed
700

701
class EmbeddingCompletionRequest(OpenAIBaseModel):
    # Embedding request whose `input` is a string, a list of strings, a list
    # of token ids, or a list of token-id lists.

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    # When provided, must be an integer >= 1 (`Field(ge=1)`).
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    # doc: begin-embedding-extra-params
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-embedding-extra-params

    def to_pooling_params(self):
        """Return `PoolingParams` built from this request's `additional_data`.

        No other request field is forwarded.
        """
        return PoolingParams(additional_data=self.additional_data)


735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
class EmbeddingChatRequest(OpenAIBaseModel):
    # Embedding request expressed as a chat conversation: the content to
    # embed is supplied through `messages`.
    model: str
    messages: List[ChatCompletionMessageParam]

    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-chat-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-chat-embedding-pooling-params

    # doc: begin-chat-embedding-extra-params
    add_generation_prompt: bool = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat "
            "template. This is a parameter used by chat template in "
            "tokenizer config of the model."),
    )
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the "
            "final message in the chat is open-ended, without any EOS "
            "tokens. The model will continue this message rather than "
            "starting a new one. This allows you to \"prefill\" part of "
            "the model's response for it. Cannot be used at the same "
            "time as `add_generation_prompt`."),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the "
            "prompt on top of what is added by the chat template. For "
            "most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. As of "
            "transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the "
            "tokenizer does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=(
            "Additional kwargs to pass to the template renderer. Will "
            "be accessible by the chat template."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier "
            "handling; default: 0). Any priority other than 0 will "
            "raise an error if the served model does not use priority "
            "scheduling."),
    )
    # doc: end-chat-embedding-extra-params

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Refuse input that sets both prompt-shaping flags to a truthy value.

        Raises:
            ValueError: if `continue_final_message` and
                `add_generation_prompt` are both truthy in `data`.
        """
        if not data.get("continue_final_message"):
            return data
        if not data.get("add_generation_prompt"):
            return data
        raise ValueError("Cannot set both `continue_final_message` and "
                         "`add_generation_prompt` to True.")

    def to_pooling_params(self):
        """Return `PoolingParams` built from this request's `additional_data`."""
        pooling_params = PoolingParams(additional_data=self.additional_data)
        return pooling_params


# Any embedding request: the `input`-based form or the `messages`-based form.
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]


811
class CompletionLogProbs(OpenAIBaseModel):
    # Log-probability payload carried by text-completion choices.
    # Every field is a list that defaults to empty.
    # NOTE(review): the lists presumably share one index per token - confirm
    # against the code that fills them.
    text_offset: List[int] = Field(default_factory=list)
    # Individual entries may be None.
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    # Each entry is a mapping from string to float, or None.
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)
Zhuohan Li's avatar
Zhuohan Li committed
817
818


819
class CompletionResponseChoice(OpenAIBaseModel):
    # One choice of a non-streaming text completion response.
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    # One optional entry per list position, each mapping an int key to a
    # `Logprob`.
    # NOTE(review): keys are presumably token ids, one entry per prompt
    # token - confirm with the producer.
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
Zhuohan Li's avatar
Zhuohan Li committed
832
833


834
class CompletionResponse(OpenAIBaseModel):
    # Full (non-streaming) text completion response.
    # Defaults to "cmpl-" followed by a freshly generated random uuid.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Defaults to the current time as whole seconds (`int(time.time())`).
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


843
class CompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed text completion chunk. Unlike
    # `CompletionResponseChoice`, it has no `prompt_logprobs` field.
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
Zhuohan Li's avatar
Zhuohan Li committed
855
856


857
class CompletionStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed text completion response.
    # Defaults to "cmpl-" followed by a freshly generated random uuid.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Defaults to the current time as whole seconds (`int(time.time())`).
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    # Optional here, whereas `CompletionResponse.usage` is required.
    usage: Optional[UsageInfo] = Field(default=None)
864
865


866
class EmbeddingResponseData(OpenAIBaseModel):
    # A single embedding entry inside an `EmbeddingResponse`.
    index: int
    object: str = "embedding"
    # Either a list of floats or a string.
    # NOTE(review): the string form presumably corresponds to
    # `encoding_format="base64"` on the request - confirm with the producer.
    embedding: Union[List[float], str]
870
871


872
class EmbeddingResponse(OpenAIBaseModel):
    # Response body for an embedding request.
    # Defaults to "embd-" followed by a freshly generated random uuid.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    # Defaults to the current time as whole seconds (`int(time.time())`).
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


881
882
883
884
885
886
887
888
889
890
891
# Function invocation carried by a `ToolCall`. Both fields are required
# strings.
# NOTE(review): `arguments` is presumably a JSON-encoded string - confirm
# with the tool parsers that build it.
class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


# A complete tool call. `type` only admits the literal "function".
class ToolCall(OpenAIBaseModel):
    # Defaults to "chatcmpl-tool-" followed by a freshly generated random uuid.
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
# Partial counterpart of `FunctionCall`: both fields may be omitted.
# Derives from plain `BaseModel`, so the `extra="forbid"` configuration of
# `OpenAIBaseModel` does not apply to it.
class DeltaFunctionCall(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


# A tool call delta used in streamed messages (`DeltaMessage.tool_calls`).
# `function` is optional and `id` / `type` have defaults, but `index` is
# required.
class DeltaToolCall(OpenAIBaseModel):
    # A new random id is generated whenever `id` is not supplied.
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    index: int
    function: Optional[DeltaFunctionCall] = None


# Container for the outcome of extracting tool calls: a flag, the calls
# themselves and any accompanying text content.
# Derives from plain `BaseModel`, so the `extra="forbid"` configuration of
# `OpenAIBaseModel` does not apply to it.
class ExtractedToolCallInformation(BaseModel):
    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: List[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: Optional[str] = None


917
class ChatMessage(OpenAIBaseModel):
    # Message returned in a non-streaming chat completion choice.
    role: str
    content: Optional[str] = None
    # Defaults to an empty list rather than None.
    tool_calls: List[ToolCall] = Field(default_factory=list)
921
922


923
924
925
926
927
928
929
930
931
932
933
934
935
936
# Log probability of a single token in a chat completion.
class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    # NOTE(review): -9999.0 is the value used when no logprob is supplied;
    # presumably a sentinel - confirm the intended meaning.
    logprob: float = -9999.0
    # NOTE(review): presumably the token's encoded bytes as integers -
    # confirm with the producer.
    bytes: Optional[List[int]] = None


# A `ChatCompletionLogProb` extended with a list of further
# `ChatCompletionLogProb` entries; the list defaults to empty.
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


# Log-probability payload attached to chat completion choices.
class ChatCompletionLogProbs(OpenAIBaseModel):
    # None by default; otherwise a list of per-entry logprob records.
    content: Optional[List[ChatCompletionLogProbsContent]] = None


937
class ChatCompletionResponseChoice(OpenAIBaseModel):
    # One choice of a non-streaming chat completion response.
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # per OpenAI spec this is the default
    finish_reason: Optional[str] = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: Optional[Union[int, str]] = None
945
946


947
class ChatCompletionResponse(OpenAIBaseModel):
    # Full (non-streaming) chat completion response.
    # Defaults to "chatcmpl-" followed by a freshly generated random uuid.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    # Only the literal "chat.completion" is accepted.
    object: Literal["chat.completion"] = "chat.completion"
    # Defaults to the current time as whole seconds (`int(time.time())`).
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    # Attached to the response as a whole, not to individual choices.
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
955
956


957
class DeltaMessage(OpenAIBaseModel):
    # Incremental message carried by a streamed chat completion choice.
    # Unlike `ChatMessage`, `role` is optional here.
    role: Optional[str] = None
    content: Optional[str] = None
    # Defaults to an empty list rather than None.
    tool_calls: List[DeltaToolCall] = Field(default_factory=list)
961
962


963
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed chat completion chunk.
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # Defaults to None here, whereas the non-streaming
    # `ChatCompletionResponseChoice` defaults to "stop".
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None
969
970


971
class ChatCompletionStreamResponse(OpenAIBaseModel):
    # One chunk of a streamed chat completion response.
    # Defaults to "chatcmpl-" followed by a freshly generated random uuid.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    # Only the literal "chat.completion.chunk" is accepted.
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Defaults to the current time as whole seconds (`int(time.time())`).
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    # Optional here, whereas `ChatCompletionResponse.usage` is required.
    usage: Optional[UsageInfo] = Field(default=None)
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: `body` is validated as either a chat completion request or an
    embedding request.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    # NOTE(review): `body` also accepts embedding requests; confirm which
    # URLs the batch runner actually routes and update this comment.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]
1001
1002


1003
1004
1005
1006
1007
1008
1009
1010
class BatchResponseData(OpenAIBaseModel):
    # Response portion of one batch output line.

    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response: a chat completion response, an embedding
    # response, or None (the default).
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None
1012
1013


1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    # May be None, but has no default: it must always be passed explicitly.
    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    # Like `response`, this may be None but must be passed explicitly.
    error: Optional[Any]
1030
1031


1032
1033
1034
1035
class TokenizeCompletionRequest(OpenAIBaseModel):
    # Tokenization request for a plain string prompt.
    model: str
    prompt: str

    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
1042
1043
1044
1045
1046
1047


class TokenizeChatRequest(OpenAIBaseModel):
    # Tokenization request for a chat conversation supplied as `messages`.
    model: str
    messages: List[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    continue_final_message: bool = Field(
        default=False,
        description=
        ("If this is set, the chat will be formatted so that the final "
         "message in the chat is open-ended, without any EOS tokens. The "
         "model will continue this message rather than starting a new one. "
         "This allows you to \"prefill\" part of the model's response for it. "
         "Cannot be used at the same time as `add_generation_prompt`."),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject input that sets both prompt-shaping flags to a truthy value.

        Returns:
            `data`, unchanged.

        Raises:
            ValueError: if `continue_final_message` and
                `add_generation_prompt` are both truthy in `data`.
        """
        # NOTE(review): this inspects the incoming data only. When
        # `add_generation_prompt` is omitted, `data.get` yields None here even
        # though the field later defaults to True - confirm that is intended.
        if data.get("continue_final_message") and data.get(
                "add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data

1096
1097

# Any tokenization request: the `prompt`-based form or the `messages`-based
# form.
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
1098
1099
1100
1101
1102


class TokenizeResponse(OpenAIBaseModel):
    # Response body for a tokenization request. All fields are required.
    # NOTE(review): `count` is presumably `len(tokens)` - confirm with the
    # handler that builds this response.
    count: int
    max_model_len: int
    tokens: List[int]
1104
1105
1106
1107
1108
1109
1110
1111
1112


# Detokenization request: a model name plus a list of integer tokens.
class DetokenizeRequest(OpenAIBaseModel):
    model: str
    tokens: List[int]


# Response body for a detokenization request: a single string field.
class DetokenizeResponse(OpenAIBaseModel):
    prompt: str
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122


# Request to load a LoRA adapter, identified by a name and a path.
# Derives from plain `BaseModel`, so the `extra="forbid"` configuration of
# `OpenAIBaseModel` does not apply to it.
class LoadLoraAdapterRequest(BaseModel):
    lora_name: str
    lora_path: str


# Request to unload a LoRA adapter by name, optionally with an integer id.
# Derives from plain `BaseModel`, so the `extra="forbid"` configuration of
# `OpenAIBaseModel` does not apply to it.
class UnloadLoraAdapterRequest(BaseModel):
    lora_name: str
    # May be omitted; defaults to None.
    lora_int_id: Optional[int] = Field(default=None)