protocol.py 44.9 KB
Newer Older
1
2
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
3
import re
Zhuohan Li's avatar
Zhuohan Li committed
4
import time
5
from argparse import Namespace
6
from typing import Any, Dict, List, Literal, Optional, Union
Zhuohan Li's avatar
Zhuohan Li committed
7

8
import torch
9
from pydantic import BaseModel, ConfigDict, Field, model_validator
10
from typing_extensions import Annotated
Zhuohan Li's avatar
Zhuohan Li committed
11

12
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
13
from vllm.logger import init_logger
14
from vllm.pooling_params import PoolingParams
15
16
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
17
from vllm.sequence import Logprob
18
from vllm.utils import random_uuid, resolve_obj_by_qualname
19

20
21
# Module logger; OpenAIBaseModel uses it to warn about request fields that
# do not correspond to any declared model field.
logger = init_logger(__name__)

22
23
24
# torch is mocked during docs generation, so the int64 bounds are also
# spelled out as literals; `_LONG_INFO` is whichever source is usable.
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
_LONG_INFO: Union["torch.iinfo", Namespace]

try:
    from sphinx.ext.autodoc.mock import _MockModule
except ModuleNotFoundError:
    # Without sphinx, torch cannot be a sphinx autodoc mock.
    _TORCH_IS_MOCKED = False
else:
    _TORCH_IS_MOCKED = isinstance(torch, _MockModule)

_LONG_INFO = _MOCK_LONG_INFO if _TORCH_IS_MOCKED else torch.iinfo(torch.long)

# The literal fallback must agree with the real torch.long range.
assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max

Zhuohan Li's avatar
Zhuohan Li committed
40

41
class OpenAIBaseModel(BaseModel):
    # Common base for the OpenAI-compatible models in this module: unknown
    # keys are accepted (extra="allow"), and a warning naming them is logged
    # so clients can see which parameters did not match a declared field.

    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    @model_validator(mode="before")
    @classmethod
    def __log_extra_fields__(cls, data):
        """Warn about request keys that match no declared field or alias.

        Field aliases count as known keys: e.g. ``JsonSchemaResponseFormat``
        receives its ``json_schema`` field under the ``schema`` key, which
        must not be reported as ignored.

        Returns ``data`` unchanged; non-dict input is passed through.
        """
        if isinstance(data, dict):
            known_keys = set()
            for field_name, field_info in cls.model_fields.items():
                known_keys.add(field_name)
                if field_info.alias is not None:
                    known_keys.add(field_info.alias)
            extra_fields = data.keys() - known_keys
            if extra_fields:
                logger.warning(
                    "The following fields were present in the request "
                    "but ignored: %s", extra_fields)
        return data
55
56
57


class ErrorResponse(OpenAIBaseModel):
    # OpenAI-style error body; `object` defaults to "error", `param` is
    # optional, and `message`, `type` and `code` are required.
    object: str = Field(default="error")
    message: str = Field(...)
    type: str = Field(...)
    param: Union[str, None] = Field(default=None)
    code: int = Field(...)
Zhuohan Li's avatar
Zhuohan Li committed
63
64


65
class ModelPermission(OpenAIBaseModel):
    # Permission record listed under `ModelCard.permission`. Every field has
    # a default, so `ModelPermission()` needs no arguments; `id` and
    # `created` are generated per instance.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = Field(default="model_permission")
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = Field(default=False)
    allow_sampling: bool = Field(default=True)
    allow_logprobs: bool = Field(default=True)
    allow_search_indices: bool = Field(default=False)
    allow_view: bool = Field(default=True)
    allow_fine_tuning: bool = Field(default=False)
    organization: str = Field(default="*")
    group: Union[str, None] = Field(default=None)
    is_blocking: bool = Field(default=False)
Zhuohan Li's avatar
Zhuohan Li committed
78
79


80
class ModelCard(OpenAIBaseModel):
    # Description of one model. Only `id` is required; `created` is stamped
    # at construction time and `permission` starts as an empty list.
    id: str = Field(...)
    object: str = Field(default="model")
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = Field(default="vllm")
    root: Union[str, None] = Field(default=None)
    parent: Union[str, None] = Field(default=None)
    max_model_len: Union[int, None] = Field(default=None)
    permission: List[ModelPermission] = Field(default_factory=list)


91
class ModelList(OpenAIBaseModel):
    # Container of ModelCard entries; `object` defaults to "list" and
    # `data` to a fresh empty list.
    object: str = Field(default="list")
    data: List[ModelCard] = Field(default_factory=list)


96
97
98
99
class PromptTokenUsageInfo(OpenAIBaseModel):
    # Extra detail about prompt tokens; `cached_tokens` stays None unless set.
    cached_tokens: Union[int, None] = Field(default=None)


100
class UsageInfo(OpenAIBaseModel):
    # Token usage counters. All counts default to 0 (`completion_tokens` may
    # also be None); the prompt-token breakdown is optional.
    prompt_tokens: int = Field(default=0)
    total_tokens: int = Field(default=0)
    completion_tokens: Union[int, None] = Field(default=0)
    prompt_tokens_details: Union[PromptTokenUsageInfo,
                                 None] = Field(default=None)
Zhuohan Li's avatar
Zhuohan Li committed
105
106


107
108
109
110
111
class RequestResponseMetadata(BaseModel):
    # Plain pydantic model (pydantic's default extra-field handling, not
    # OpenAIBaseModel's) pairing a request id with its final usage, if any.
    request_id: str = Field(...)
    final_usage_info: Union[UsageInfo, None] = Field(default=None)


112
113
114
115
116
117
118
119
120
class JsonSchemaResponseFormat(OpenAIBaseModel):
    # The `json_schema` object of a `response_format` request field.
    name: str = Field(...)
    description: Union[str, None] = Field(default=None)
    # OpenAI names this field `schema`, which conflicts with pydantic, so it
    # is stored as `json_schema` and populated through the `schema` alias.
    json_schema: Union[Dict[str, Any], None] = Field(default=None,
                                                     alias='schema')
    strict: Union[bool, None] = Field(default=None)


121
class ResponseFormat(OpenAIBaseModel):
    # Requested output format: "text", "json_object" or "json_schema".
    # `json_schema` carries the schema for the "json_schema" type.
    type: Literal["text", "json_object", "json_schema"] = Field(...)
    json_schema: Union[JsonSchemaResponseFormat, None] = Field(default=None)
125
126


127
class StreamOptions(OpenAIBaseModel):
    # Streaming-only options; usage is included by default, per-chunk
    # ("continuous") usage statistics are off by default.
    include_usage: Union[bool, None] = Field(default=True)
    continuous_usage_stats: Union[bool, None] = Field(default=False)
130
131


132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
class FunctionDefinition(OpenAIBaseModel):
    # A function offered to the model as a tool. `parameters` is an arbitrary
    # dict (presumably a JSON schema — not validated here).
    name: str = Field(...)
    description: Union[str, None] = Field(default=None)
    parameters: Union[Dict[str, Any], None] = Field(default=None)


class ChatCompletionToolsParam(OpenAIBaseModel):
    # One entry of a chat request's `tools` list; only "function" tools.
    type: Literal["function"] = Field(default="function")
    function: FunctionDefinition = Field(...)


class ChatCompletionNamedFunction(OpenAIBaseModel):
    # Names the function selected by a named `tool_choice`.
    name: str = Field(...)


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    # `tool_choice` value forcing a specific function, e.g.
    # {"type": "function", "function": {"name": "my_function"}}.
    function: ChatCompletionNamedFunction = Field(...)
    type: Literal["function"] = Field(default="function")


152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
class LogitsProcessorConstructor(BaseModel):
    # A logits processor given as a factory call: `qualname` names the
    # class/factory, `args`/`kwargs` are passed to it when resolved.
    qualname: str = Field(...)
    args: Union[List[Any], None] = Field(default=None)
    kwargs: Union[Dict[str, Any], None] = Field(default=None)


# Each entry is either a bare qualified name or a constructor spec.
LogitsProcessors = List[Union[str, LogitsProcessorConstructor]]


def get_logits_processors(processors: Optional[LogitsProcessors],
                          pattern: Optional[str]) -> Optional[List[Any]]:
    """Resolve request-supplied logits processors.

    Args:
        processors: Entries of a request's ``logits_processors`` field; each
            is a qualified name or a ``LogitsProcessorConstructor``.
        pattern: Allow-list regex (the ``--logits-processor-pattern`` engine
            argument). It is applied with ``re.match``, so it only has to
            match a prefix of the qualified name unless anchored with ``$``.
            A falsy pattern means custom logits processors are disabled.

    Returns:
        The resolved objects in request order (constructor entries are
        called with their ``args``/``kwargs``), or ``None`` when
        ``processors`` is empty or ``None``.

    Raises:
        ValueError: if processors are supplied but no pattern is configured,
            a qualified name does not match ``pattern``, or a name cannot be
            resolved.
    """
    if not processors:
        return None
    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information.")

    logits_processors = []
    for processor in processors:
        qualname = (processor
                    if isinstance(processor, str) else processor.qualname)
        if not re.match(pattern, qualname):
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information.")
        try:
            logits_processor = resolve_obj_by_qualname(qualname)
        except Exception as e:
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {e}"
            ) from e
        if isinstance(processor, LogitsProcessorConstructor):
            # Constructor specs name a factory; call it to get the processor.
            logits_processor = logits_processor(*(processor.args or []),
                                                **(processor.kwargs or {}))
        logits_processors.append(logits_processor)
    return logits_processors


192
class ChatCompletionRequest(OpenAIBaseModel):
    # Request body for chat completions: the OpenAI fields first, then
    # vLLM-specific sampling and extra parameters. `to_sampling_params` and
    # `to_beam_search_params` translate it into engine parameters.

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    # TODO(#9845): remove max_tokens when field is removed from OpenAI API
    max_tokens: Optional[int] = Field(
        default=None,
        deprecated=
        'max_tokens is deprecated in favor of the max_completion_tokens field')
    max_completion_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"], Literal["auto"],
                                ChatCompletionNamedToolChoiceParam]] = "none"

    # NOTE this will be ignored by VLLM -- the model determines the behavior
    parallel_tool_calls: Optional[bool] = False
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: bool = False
    top_k: int = -1
    min_p: float = 0.0
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    prompt_logprobs: Optional[int] = None
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: bool = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: bool = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    continue_final_message: bool = Field(
        default=False,
        description=
        ("If this is set, the chat will be formatted so that the final "
         "message in the chat is open-ended, without any EOS tokens. The "
         "model will continue this message rather than starting a new one. "
         "This allows you to \"prefill\" part of the model's response for it. "
         "Cannot be used at the same time as `add_generation_prompt`."),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=
        ("A list of dicts representing documents that will be accessible to "
         "the model if it is performing RAG (retrieval-augmented generation)."
         " If the template does not support RAG, this argument will have no "
         "effect. We recommend that each document should be a dict containing "
         "\"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."))
    logits_processors: Optional[LogitsProcessors] = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."))

    # doc: end-chat-completion-extra-params

    def to_beam_search_params(self,
                              default_max_tokens: int) -> BeamSearchParams:
        """Build beam-search parameters from this request.

        ``n`` becomes the beam width (1 if unset). The token limit is
        ``max_completion_tokens`` if truthy, otherwise ``max_tokens``, and
        ``default_max_tokens`` when that is ``None``. A ``None`` temperature
        becomes 0.0.
        """
        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
        max_tokens = self.max_completion_tokens or self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        n = self.n if self.n is not None else 1
        temperature = self.temperature if self.temperature is not None else 0.0

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output)

    def to_sampling_params(
            self, default_max_tokens: int,
            logits_processor_pattern: Optional[str]) -> SamplingParams:
        """Translate this request into engine ``SamplingParams``.

        Args:
            default_max_tokens: Used when neither ``max_completion_tokens``
                nor ``max_tokens`` yields a value.
            logits_processor_pattern: Allow-list regex forwarded to
                ``get_logits_processors``.

        A ``json_schema`` response format is copied into ``self.guided_json``
        (and ``self.guided_decoding_backend`` defaults to ``"xgrammar"``), so
        this method mutates the request.

        Raises:
            ValueError: if ``response_format.type`` is ``"json_schema"``
                without a ``json_schema``, a named tool in ``tool_choice`` is
                missing from ``tools``, or the requested logits processors are
                not allowed / cannot be resolved.
        """
        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
        max_tokens = self.max_completion_tokens or self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # With echo, prompt logprobs default to the requested top_logprobs.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        guided_json_object = None
        if self.response_format is not None:
            if self.response_format.type == "json_object":
                guided_json_object = True
            elif self.response_format.type == "json_schema":
                json_schema = self.response_format.json_schema
                # Client input: raise (not assert, which `-O` strips) so it is
                # reported like the other request validation errors.
                if json_schema is None:
                    raise ValueError(
                        "`response_format.json_schema` must be provided when "
                        "`response_format.type` is \"json_schema\".")
                self.guided_json = json_schema.json_schema
                if self.guided_decoding_backend is None:
                    self.guided_decoding_backend = "xgrammar"

        guided_decoding = GuidedDecodingParams.from_optional(
            # A named tool's parameter schema takes precedence.
            json=self._get_guided_json_from_tool() or self.guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern)

        output_kind = (RequestOutputKind.DELTA
                       if self.stream else RequestOutputKind.FINAL_ONLY)

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            logits_processors=get_logits_processors(self.logits_processors,
                                                    logits_processor_pattern),
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=output_kind,
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias)

    def _get_guided_json_from_tool(
            self) -> Optional[Union[str, dict, BaseModel]]:
        """Return the ``parameters`` of the tool named in ``tool_choice``.

        Returns ``None`` when ``tool_choice`` is ``"none"``, no ``tools``
        were given, or ``tool_choice`` is not a named tool (e.g. ``"auto"``).

        Raises:
            ValueError: if the named tool is not among ``tools``.
        """
        # user has chosen to not use any tool
        if self.tool_choice == "none" or self.tools is None:
            return None

        # user has chosen to use a named tool
        if isinstance(self.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_name = self.tool_choice.function.name
            tools = {tool.function.name: tool.function for tool in self.tools}
            if tool_name not in tools:
                raise ValueError(
                    f"Tool '{tool_name}' has not been passed in `tools`.")
            return tools[tool_name].parameters

        return None

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject ``stream_options`` unless ``stream`` is true."""
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate ``prompt_logprobs`` / ``top_logprobs``.

        Both must be non-negative; positive ``prompt_logprobs`` cannot be
        combined with streaming; ``top_logprobs`` requires ``logprobs``.
        """
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")

            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        if (top_logprobs := data.get("top_logprobs")) is not None:
            if top_logprobs < 0:
                raise ValueError("`top_logprobs` must be a positive value.")

            if not data.get("logprobs"):
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Reject conflicting structured-output constraints.

        At most one of ``guided_json``, ``guided_regex`` and ``guided_choice``
        may be set, and none of them may be combined with a named
        ``tool_choice``. Otherwise ``to_sampling_params`` would silently
        replace the guided constraint with the tool's schema.
        """
        # NOTE(review): presumably guards against an upstream validator
        # passing an exception object through — TODO confirm.
        if isinstance(data, ValueError):
            raise data

        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both.
        # Any single guided constraint conflicts with a named tool; counts
        # above 1 were already rejected by the previous check.
        if guide_count > 0 and data.get("tool_choice",
                                        "none") not in ("none", "auto"):
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        """Normalize and validate ``tools`` / ``tool_choice`` in ``data``.

        - ``tools`` without ``tool_choice``: ``tool_choice`` becomes "auto".
        - ``tool_choice == "none"``: ``tools`` is removed from ``data``.
        - Otherwise ``tools`` must be present and ``tool_choice`` must be
          "auto" or a named-tool dict whose function name is in ``tools``.

        Mutates and returns ``data``.
        """
        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        # if "tool_choice" is "none" -- ignore tools if present
        if "tool_choice" in data and data["tool_choice"] == "none":
            # ensure that no tools are present
            data.pop("tools", None)
            return data

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data:

            # ensure that if "tool choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto"
            if data["tool_choice"] != "auto" and not isinstance(
                    data["tool_choice"], dict):
                raise ValueError(
                    "`tool_choice` must either be a named tool, \"auto\", "
                    "or \"none\".")

            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            if isinstance(data["tool_choice"], dict):
                specified_function = data["tool_choice"].get("function")
                if not specified_function:
                    raise ValueError(
                        "Expected field `function` in `tool_choice`."
                        " Correct usage: `{\"type\": \"function\","
                        " \"function\": {\"name\": \"my_function\"}}`")
                specified_function_name = specified_function.get("name")
                if not specified_function_name:
                    raise ValueError(
                        "Expected field `name` in `function` in `tool_choice`."
                        " Correct usage: `{\"type\": \"function\", "
                        "\"function\": {\"name\": \"my_function\"}}`")
                valid_tool = any(
                    tool["function"]["name"] == specified_function_name
                    for tool in data["tools"])
                if not valid_tool:
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match any"
                        " of the specified `tools`")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject ``continue_final_message`` together with an explicitly
        truthy ``add_generation_prompt`` in the raw request data."""
        if data.get("continue_final_message") and data.get(
                "add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data

Zhuohan Li's avatar
Zhuohan Li committed
568

569
class CompletionRequest(OpenAIBaseModel):
    # Request body for (text) completions: the OpenAI fields first, then
    # vLLM-specific sampling and extra parameters. `to_sampling_params` and
    # `to_beam_search_params` translate it into engine parameters.

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: bool = False
    top_k: int = -1
    min_p: float = 0.0
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    prompt_logprobs: Optional[int] = None
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
         "{'type': 'text' } is supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description="If specified, the output will follow the JSON schema.",
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))
    logits_processors: Optional[LogitsProcessors] = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."))

    # doc: end-completion-extra-params

    def to_beam_search_params(self,
                              default_max_tokens: int) -> BeamSearchParams:
        """Build beam-search parameters from this request.

        ``n`` becomes the beam width (1 if unset); ``max_tokens`` falls back
        to ``default_max_tokens`` when ``None``; a ``None`` temperature
        becomes 0.0.
        """
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        n = self.n if self.n is not None else 1
        temperature = self.temperature if self.temperature is not None else 0.0

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output)

    def to_sampling_params(
            self, default_max_tokens: int,
            logits_processor_pattern: Optional[str]) -> SamplingParams:
        """Translate this request into engine ``SamplingParams``.

        Args:
            default_max_tokens: Used when ``max_tokens`` is ``None``.
            logits_processor_pattern: Allow-list regex forwarded to
                ``get_logits_processors``.

        A ``json_schema`` response format is copied into ``self.guided_json``
        (and ``self.guided_decoding_backend`` defaults to ``"xgrammar"``), so
        this method mutates the request, as ChatCompletionRequest does.

        Raises:
            ValueError: if ``response_format.type`` is ``"json_schema"``
                without a ``json_schema``, or the requested logits processors
                are not allowed / cannot be resolved.
        """
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # With echo, prompt logprobs default to the requested logprobs.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        # echo with max_tokens == 0 still asks the engine for one token
        # (presumably discarded by the serving layer — TODO confirm).
        echo_without_generation = self.echo and self.max_tokens == 0

        guided_json_object = None
        if self.response_format is not None:
            if self.response_format.type == "json_object":
                guided_json_object = True
            elif self.response_format.type == "json_schema":
                # Advertised in the `response_format` description; handled
                # the same way as for chat completions.
                json_schema = self.response_format.json_schema
                if json_schema is None:
                    raise ValueError(
                        "`response_format.json_schema` must be provided when "
                        "`response_format.type` is \"json_schema\".")
                self.guided_json = json_schema.json_schema
                if self.guided_decoding_backend is None:
                    self.guided_decoding_backend = "xgrammar"

        guided_decoding = GuidedDecodingParams.from_optional(
            json=self.guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern)

        output_kind = (RequestOutputKind.DELTA
                       if self.stream else RequestOutputKind.FINAL_ONLY)

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(self.logits_processors,
                                                    logits_processor_pattern),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=output_kind,
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids)

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Allow at most one of ``guided_json``/``guided_regex``/
        ``guided_choice`` to be set."""
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate ``prompt_logprobs`` / ``logprobs``: both non-negative,
        and positive ``prompt_logprobs`` cannot be used with streaming."""
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")

            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise ValueError("`logprobs` must be a positive value.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject ``stream_options`` unless ``stream`` is true."""
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data

Zhuohan Li's avatar
Zhuohan Li committed
785

786
class EmbeddingCompletionRequest(OpenAIBaseModel):
    # Embedding request whose input is raw text or token ids (the
    # chat-messages variant is EmbeddingChatRequest).
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    # A token-id list, a list of token-id lists, a string, or a list of
    # strings.
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    # Optional; when given it must be >= 1 (enforced by Field(ge=1)).
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    # doc: begin-embedding-extra-params
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-embedding-extra-params

    def to_pooling_params(self):
        """Build the ``PoolingParams`` for this request.

        Only ``additional_data`` is forwarded. The other request fields are
        not used here.
        """
        return PoolingParams(additional_data=self.additional_data)


# Embedding request whose input is a list of chat messages rather than raw
# text or token ids (the completion-style variant is
# EmbeddingCompletionRequest).
class EmbeddingChatRequest(OpenAIBaseModel):
    model: str
    messages: List[ChatCompletionMessageParam]

    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    # Optional; when given it must be >= 1 (enforced by Field(ge=1)).
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-chat-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-chat-embedding-pooling-params

    # doc: begin-chat-embedding-extra-params
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))
    # doc: end-chat-embedding-extra-params

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject payloads where both ``continue_final_message`` and
        ``add_generation_prompt`` are truthy.

        NOTE(review): neither key is a declared field on this model, so they
        can only reach here as undeclared extra fields. Confirm whether this
        check is still meaningful for embedding requests.
        """
        if data.get("continue_final_message") and data.get(
                "add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data

    def to_pooling_params(self):
        """Build the ``PoolingParams`` for this request.

        Only ``additional_data`` is forwarded.
        """
        return PoolingParams(additional_data=self.additional_data)


# Any embedding request: plain text/token-id input or chat-messages input.
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]


class ScoreRequest(OpenAIBaseModel):
    # Score request over two inputs, `text_1` and `text_2`. Each may be a
    # single string or a list of strings. How the two sides are paired is
    # decided by the serving code, not visible here.
    model: str
    text_1: Union[List[str], str]
    text_2: Union[List[str], str]
    # Optional; when given it must be >= 1 (enforced by Field(ge=1)).
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-score-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-score-pooling-params

    # doc: begin-score-extra-params
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-score-extra-params

    def to_pooling_params(self):
        """Build the ``PoolingParams`` for this request.

        Only ``additional_data`` is forwarded.
        """
        return PoolingParams(additional_data=self.additional_data)


class CompletionLogProbs(OpenAIBaseModel):
    # Logprobs payload for the (legacy) completions API. Every field is a
    # list that defaults to empty; presumably the lists are index-aligned
    # per generated token (confirm against the producer).
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    # Per position: a mapping of candidate token -> logprob, or None.
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)


class CompletionResponseChoice(OpenAIBaseModel):
    # One choice in a non-streaming completions response.
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    # Per prompt position: an optional mapping of token id -> Logprob.
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
    # Non-streaming completions response body.
    # Fresh "cmpl-<uuid>" id per instance unless one is supplied.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Unix timestamp (whole seconds) taken when the model is constructed.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streaming completions chunk.
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    # One chunk of a streaming completions response.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Unix timestamp (whole seconds) taken when the chunk is constructed.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    # Optional here (unlike CompletionResponse); None unless set.
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(OpenAIBaseModel):
    # One embedding entry in an EmbeddingResponse.
    index: int
    object: str = "embedding"
    # A list of floats, or a string; presumably the string form is the
    # base64 encoding requested via encoding_format="base64" (confirm).
    embedding: Union[List[float], str]


class EmbeddingResponse(OpenAIBaseModel):
    # Embeddings response body.
    # Fresh "embd-<uuid>" id per instance unless one is supplied.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    # Unix timestamp (whole seconds) taken when the model is constructed.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class ScoreResponseData(OpenAIBaseModel):
    # One score entry in a ScoreResponse.
    index: int
    object: str = "score"
    score: float


class ScoreResponse(OpenAIBaseModel):
    # Score response body: a list of ScoreResponseData entries plus usage.
    # NOTE(review): ids use the "embd-" prefix, the same prefix
    # EmbeddingResponse uses. Confirm whether a score-specific prefix was
    # intended before relying on the id to tell response kinds apart.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    # Unix timestamp (whole seconds) taken when the model is constructed.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[ScoreResponseData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    # A complete function call: the function name and its arguments.
    name: str
    # Arguments as a single string, not a parsed object; presumably
    # JSON-encoded per the OpenAI convention (confirm with producers).
    arguments: str


class ToolCall(OpenAIBaseModel):
    # A complete tool call. Only "function" tool calls are representable.
    # Fresh "chatcmpl-tool-<uuid>" id per instance unless one is supplied.
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class DeltaFunctionCall(BaseModel):
    # Streaming counterpart of FunctionCall: both parts are optional so a
    # chunk can carry only the name or only a slice of the arguments.
    # Derives from pydantic BaseModel directly, so it does not inherit
    # OpenAIBaseModel's model_config.
    name: Optional[str] = None
    arguments: Optional[str] = None


# A tool-call delta for streaming responses. `index` is required and
# `function` is optional.
# NOTE(review): despite being described as "everything optional", `id` and
# `type` always get a value when not passed (a fresh random id / "function"),
# so they are never None. Confirm callers account for this on follow-up
# deltas of the same call.
class DeltaToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    # Presumably the position of this call among the message's tool calls.
    index: int
    function: Optional[DeltaFunctionCall] = None


class ExtractedToolCallInformation(BaseModel):
    # Tool-call information extracted from model output (presumably produced
    # by a tool-call parser; the producer is not visible here).
    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: List[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: Optional[str] = None


class ChatMessage(OpenAIBaseModel):
    # A complete chat message in a non-streaming response.
    role: str
    content: Optional[str] = None
    # Empty list (not None) when no tools were called.
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
    # Log probability of a single token in a chat completion.
    token: str
    # Defaults to -9999.0 when not supplied; presumably a finite stand-in
    # for -inf, which JSON cannot encode (confirm with producers).
    logprob: float = -9999.0
    # Optional list of ints; presumably the token's byte values.
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    # Per-token entry: the token's own logprob fields (inherited) plus its
    # top alternatives, which default to an empty list.
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    # Logprobs attached to a chat completion choice. None when absent.
    content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
    # One choice in a non-streaming chat completion response.
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # per OpenAI spec this is the default
    finish_reason: Optional[str] = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    # Non-streaming chat completion response body.
    # Fresh "chatcmpl-<uuid>" id per instance unless one is supplied.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    # Unix timestamp (whole seconds) taken when the model is constructed.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    # Per prompt position: an optional mapping of token id -> Logprob.
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
    # Incremental chat message carried by a streaming chunk. Every field is
    # defaulted, presumably so a chunk can carry only what changed.
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[DeltaToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streaming chat completion chunk. Unlike the
    # non-streaming choice, finish_reason defaults to None here.
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    # One chunk of a streaming chat completion response.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Unix timestamp (whole seconds) taken when the chunk is constructed.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    # Optional here (unlike ChatCompletionResponse); None unless set.
    usage: Optional[UsageInfo] = Field(default=None)


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: `body` is typed to accept either a chat-completion request or an
    embedding request.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request, e.g.
    # /v1/chat/completions.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]


class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200

    # An unique identifier for the API request.
    request_id: str

    # The body of the response: a chat completion or an embedding response,
    # or None when there is no body.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    # The response for this request. The annotation has no default, so the
    # key must be supplied, but its value may be None.
    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]


class TokenizeCompletionRequest(OpenAIBaseModel):
    # Tokenize request for a single plain-text prompt.
    model: str
    prompt: str

    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )


class TokenizeChatRequest(OpenAIBaseModel):
    # Tokenize request built from chat messages. The fields below control how
    # the chat template renders the messages.
    model: str
    messages: List[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    continue_final_message: bool = Field(
        default=False,
        description=
        ("If this is set, the chat will be formatted so that the final "
         "message in the chat is open-ended, without any EOS tokens. The "
         "model will continue this message rather than starting a new one. "
         "This allows you to \"prefill\" part of the model's response for it. "
         "Cannot be used at the same time as `add_generation_prompt`."),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject payloads where both ``continue_final_message`` and
        ``add_generation_prompt`` are truthy.

        NOTE(review): this inspects the raw payload before defaults apply.
        A request that sets ``continue_final_message=True`` and omits
        ``add_generation_prompt`` passes this check, yet ends up with both
        flags True because ``add_generation_prompt`` defaults to True.
        Confirm whether such requests should be rejected here.

        Raises:
            ValueError: if both keys are present and truthy.
        """
        if data.get("continue_final_message") and data.get(
                "add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data


# Any tokenize request: a single plain-text prompt or a list of chat messages.
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]


class TokenizeResponse(OpenAIBaseModel):
    # Result of a tokenize request. `count` and `tokens` are independent
    # fields here; presumably count == len(tokens) (confirm with producer).
    count: int
    max_model_len: int
    tokens: List[int]


class DetokenizeRequest(OpenAIBaseModel):
    # Request to turn a list of token ids back into text.
    model: str
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    # Text produced from the token ids of a DetokenizeRequest.
    prompt: str


class LoadLoraAdapterRequest(BaseModel):
    # Request body for loading a LoRA adapter by name from a path. Derives
    # from pydantic BaseModel, not OpenAIBaseModel.
    lora_name: str
    lora_path: str


class UnloadLoraAdapterRequest(BaseModel):
    # Request body for unloading a LoRA adapter. `lora_int_id` is optional
    # and defaults to None.
    lora_name: str
    lora_int_id: Optional[int] = Field(default=None)