# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import time
from typing import Annotated, Any, ClassVar, Literal, TypeAlias
import regex as re
import torch
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    model_validator,
)
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.sampling_params import (
    BeamSearchParams,
    RequestOutputKind,
    SamplingParams,
    StructuredOutputsParams,
)
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)

# Range of a signed 64-bit integer (torch.int64 is the same dtype as
# torch.long). Used to bound integer request fields such as `seed` and
# `truncate_prompt_tokens`.
_LONG_INFO = torch.iinfo(torch.int64)


class OpenAIBaseModel(BaseModel):
    # OpenAI API does allow extra fields
    model_config = ConfigDict(extra="allow")

    # Lazily computed set of field names and aliases, cached separately on
    # each concrete subclass by `__log_extra_fields__`.
    field_names: ClassVar[set[str] | None] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Validate ``data`` and log a warning for keys matching no field.

        The regular pydantic validation (``handler``) runs first, so invalid
        input still raises as before. When the raw input is a dict, its keys
        are compared against the model's field names and aliases, and any
        unknown keys are logged.

        Returns:
            Whatever ``handler(data)`` returned (the validated model).
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Read the cache from this exact class only. `cls.field_names` would
        # fall back to a parent class's cached set via attribute inheritance,
        # so a subclass validated after its parent would be checked against
        # the parent's (smaller) field set and warn about its own fields.
        field_names = cls.__dict__.get("field_names")
        if field_names is None:
            # Get all class field names and their potential aliases
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names

        # Compare against both field names and aliases
        unknown_keys = data.keys() - field_names
        if unknown_keys:
            logger.warning(
                "The following fields were present in the request but ignored: %s",
                unknown_keys,
            )
        return result


# Error details; carried as the `error` member of ErrorResponse.
class ErrorInfo(OpenAIBaseModel):
    message: str
    type: str
    # Optional; omitted from the request -> None.
    param: str | None = Field(default=None)
    code: int


# Top-level error envelope: a single `error` object.
class ErrorResponse(OpenAIBaseModel):
    error: ErrorInfo


# Permission record listed on a ModelCard. Only `id` and `created` vary per
# instance (generated at construction); every other field has a fixed default.
class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    # Capability flags.
    allow_create_engine: bool = Field(default=False)
    allow_sampling: bool = Field(default=True)
    allow_logprobs: bool = Field(default=True)
    allow_search_indices: bool = Field(default=False)
    allow_view: bool = Field(default=True)
    allow_fine_tuning: bool = Field(default=False)
    # Scope of the permission.
    organization: str = Field(default="*")
    group: str | None = Field(default=None)
    is_blocking: bool = Field(default=False)


# One entry of a ModelList. `created` is stamped at construction time.
class ModelCard(OpenAIBaseModel):
    id: str
    object: str = Field(default="model")
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = Field(default="vllm")
    root: str | None = Field(default=None)
    parent: str | None = Field(default=None)
    max_model_len: int | None = Field(default=None)
    permission: list[ModelPermission] = Field(default_factory=list)


# List wrapper for ModelCard entries; `data` starts out as a fresh empty list.
class ModelList(OpenAIBaseModel):
    object: str = Field(default="list")
    data: list[ModelCard] = Field(default_factory=list)


# Breakdown of prompt token usage (UsageInfo.prompt_tokens_details).
class PromptTokenUsageInfo(OpenAIBaseModel):
    cached_tokens: int | None = Field(default=None)


# Token counts. `completion_tokens` defaults to 0 but may be set to None;
# the prompt-token breakdown is optional.
class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = Field(default=0)
    total_tokens: int = Field(default=0)
    completion_tokens: int | None = Field(default=0)
    prompt_tokens_details: PromptTokenUsageInfo | None = Field(default=None)


# Per-request bookkeeping (plain BaseModel, so extra keys follow pydantic's
# default handling rather than OpenAIBaseModel's extra="allow").
class RequestResponseMetadata(BaseModel):
    request_id: str
    final_usage_info: UsageInfo | None = Field(default=None)


# The `json_schema` object of a `response_format` request.
class JsonSchemaResponseFormat(OpenAIBaseModel):
    name: str
    description: str | None = Field(default=None)
    # The wire name is `schema`, which collides with a pydantic attribute, so
    # the field is stored as `json_schema` and read through an alias.
    json_schema: dict[str, Any] | None = Field(None, alias="schema")
    strict: bool | None = Field(default=None)


# One structure of the legacy structural-tag format: begin marker, optional
# schema, end marker.
class LegacyStructuralTag(OpenAIBaseModel):
    begin: str
    # Exposed on the wire as `schema`; renamed here to avoid clashing with
    # pydantic, hence the alias.
    structural_tag_schema: dict[str, Any] | None = Field(None, alias="schema")
    end: str


# Legacy `{"type": "structural_tag"}` payload: a list of tag structures plus
# trigger strings.
class LegacyStructuralTagResponseFormat(OpenAIBaseModel):
    type: Literal["structural_tag"]
    structures: list[LegacyStructuralTag]
    triggers: list[str]


# Current `{"type": "structural_tag"}` payload; `format` is accepted as-is
# (typed Any) and not validated here.
class StructuralTagResponseFormat(OpenAIBaseModel):
    type: Literal["structural_tag"]
    format: Any


# Either flavour of structural-tag response format (legacy listed first).
AnyStructuralTagResponseFormat: TypeAlias = (
    LegacyStructuralTagResponseFormat
    | StructuralTagResponseFormat
)


# Non-structural-tag `response_format` payload; `json_schema` carries the
# schema object for the "json_schema" type.
class ResponseFormat(OpenAIBaseModel):
    # type must be "json_schema", "json_object", or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: JsonSchemaResponseFormat | None = Field(default=None)


# Every accepted `response_format` shape. Member order is kept as-is since
# pydantic union resolution may depend on it.
AnyResponseFormat: TypeAlias = (
    ResponseFormat
    | StructuralTagResponseFormat
    | LegacyStructuralTagResponseFormat
)


# Streaming-only options; usage reporting is on by default, per-chunk
# ("continuous") usage stats are off by default.
class StreamOptions(OpenAIBaseModel):
    include_usage: bool | None = Field(default=True)
    continuous_usage_stats: bool | None = Field(default=False)


# A callable function description: required name, optional description and
# parameters object.
class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: str | None = Field(default=None)
    parameters: dict[str, Any] | None = Field(default=None)


# Request-side description of a logits processor to construct: its qualified
# name plus optional positional/keyword constructor arguments.
#
# Forbidding extra keys is a workaround that allows a field named `kwargs`;
# see https://github.com/pydantic/pydantic/issues/3125
class LogitsProcessorConstructor(BaseModel):
    model_config = ConfigDict(extra="forbid")

    qualname: str
    args: list[Any] | None = Field(default=None)
    kwargs: dict[str, Any] | None = Field(default=None)


# Each entry is either a bare qualified name or a constructor spec.
LogitsProcessors: TypeAlias = list[str | LogitsProcessorConstructor]


def get_logits_processors(
    processors: LogitsProcessors | None, pattern: str | None
) -> list[Any] | None:
    """Turn request-supplied logits processor specs into processor objects.

    Args:
        processors: The request's ``logits_processors`` entries. A string is a
            qualified name; a :class:`LogitsProcessorConstructor` also carries
            constructor ``args``/``kwargs``.
        pattern: Server-side allow-list regex. Each qualified name is checked
            with ``re.match`` (anchored at the start of the name only). When
            unset or empty, requesting any processor is an error.

    Returns:
        ``None`` if no processors were requested; otherwise one resolved object
        per entry, in request order.

    Raises:
        ValueError: processors were requested without a configured pattern, a
            name does not match the pattern, or a name cannot be resolved.
    """
    if not processors:
        return None

    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information."
        )

    resolved: list[Any] = []
    for spec in processors:
        name = spec if isinstance(spec, str) else spec.qualname

        if re.match(pattern, name) is None:
            raise ValueError(
                f"Logits processor '{name}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information."
            )

        try:
            target = resolve_obj_by_qualname(name)
        except Exception as exc:
            raise ValueError(
                f"Logits processor '{name}' could not be resolved: {exc}"
            ) from exc

        # Constructor specs are instantiated; bare names are used as resolved.
        if isinstance(spec, LogitsProcessorConstructor):
            target = target(*spec.args or [], **spec.kwargs or {})

        resolved.append(target)

    return resolved


class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str | None = None
    prompt: list[int] | list[list[int]] | str | list[str] | None = None
    echo: bool | None = False
    frequency_penalty: float | None = 0.0
    logit_bias: dict[str, float] | None = None
    logprobs: int | None = None
    max_tokens: int | None = 16
    n: int = 1
    presence_penalty: float | None = 0.0
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: str | list[str] | None = []
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    suffix: str | None = None
    temperature: float | None = None
    top_p: float | None = None
    user: str | None = None

    # --8<-- [start:completion-sampling-params]
    use_beam_search: bool = False
    top_k: int | None = None
    min_p: float | None = None
    repetition_penalty: float | None = None
    length_penalty: float = 1.0
    stop_token_ids: list[int] | None = []
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Annotated[int, Field(ge=-1, le=_LONG_INFO.max)] | None = (
        None
    )
    allowed_token_ids: list[int] | None = None
    prompt_logprobs: int | None = None
    # --8<-- [end:completion-sampling-params]

    # --8<-- [start:completion-extra-params]
    prompt_embeds: bytes | list[bytes] | None = None
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    response_format: AnyResponseFormat | None = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
            ", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
        ),
    )
    structured_outputs: StructuredOutputsParams | None = Field(
        default=None,
        description="Additional kwargs for structured outputs",
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    logits_processors: LogitsProcessors | None = Field(
        default=None,
        description=(
            "A list of either qualified names of logits processors, or "
            "constructor objects, to apply when sampling. A constructor is "
            "a JSON object with a required 'qualname' field specifying the "
            "qualified name of the processor class/factory, and optional "
            "'args' and 'kwargs' fields containing positional and keyword "
            "arguments. For example: {'qualname': "
            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
            "{'param': 'value'}}."
        ),
    )

    return_tokens_as_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified with 'logprobs', tokens are represented "
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."
        ),
    )
    return_token_ids: bool | None = Field(
        default=None,
        description=(
            "If specified, the result will include token IDs alongside the "
            "generated text. In streaming mode, prompt_token_ids is included "
            "only in the first chunk, and token_ids contains the delta tokens "
            "for each chunk. This is useful for debugging or when you "
            "need to map generated text back to input tokens."
        ),
    )

    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )

    # --8<-- [end:completion-extra-params]

    # Default sampling parameters for completion requests
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_beam_search_params(
        self,
        max_tokens: int,
        default_sampling_params: dict | None = None,
    ) -> BeamSearchParams:
        """Build :class:`BeamSearchParams` for this request.

        ``n`` is used as the beam width. An unset ``temperature`` falls back to
        ``default_sampling_params["temperature"]`` and then to 1.0.
        """
        if default_sampling_params is None:
            default_sampling_params = {}
        n = self.n if self.n is not None else 1

        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get("temperature", 1.0)

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
            include_stop_str_in_output=self.include_stop_str_in_output,
        )

    def to_sampling_params(
        self,
        max_tokens: int,
        logits_processor_pattern: str | None,
        default_sampling_params: dict | None = None,
    ) -> SamplingParams:
        """Build the :class:`SamplingParams` for this request.

        Args:
            max_tokens: Generation budget chosen by the caller. It is replaced
                by 1 when ``echo`` is set and the request's own ``max_tokens``
                is 0.
            logits_processor_pattern: Allow-list regex forwarded to
                :func:`get_logits_processors`.
            default_sampling_params: Server-level defaults for
                ``repetition_penalty``, ``temperature``, ``top_p``, ``top_k``
                and ``min_p`` when the request leaves them unset; missing keys
                fall back to ``_DEFAULT_SAMPLING_PARAMS``.

        Raises:
            VLLMValidationError: ``response_format.type`` is ``"json_schema"``
                but no ``json_schema`` object was given.
            ValueError: propagated from :func:`get_logits_processors`.

        Note:
            When ``response_format`` is set, ``self.structured_outputs`` is
            created if missing and updated in place.
        """
        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters: request value, then server default, then the
        # class-level fallback.
        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        # With echo, prompt logprobs default to the requested `logprobs` count.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.logprobs

        # echo + max_tokens == 0: the engine is still asked for one token
        # (see `max_tokens=` below).
        echo_without_generation = self.echo and self.max_tokens == 0

        response_format = self.response_format
        if response_format is not None:
            # If structured outputs wasn't already enabled,
            # we must enable it for these features to work
            if self.structured_outputs is None:
                self.structured_outputs = StructuredOutputsParams()

            # Set structured output params for response format
            if response_format.type == "json_object":
                self.structured_outputs.json_object = True
            elif response_format.type == "json_schema":
                json_schema = response_format.json_schema
                # Client-controlled input: report it as a validation error
                # instead of an `assert`, which would surface as a bare
                # AssertionError (or be stripped entirely under `python -O`).
                if json_schema is None:
                    raise VLLMValidationError(
                        "`response_format.json_schema` must be provided when "
                        "`response_format.type` is 'json_schema'.",
                        parameter="response_format",
                    )
                self.structured_outputs.json = json_schema.json_schema
            elif response_format.type == "structural_tag":
                structural_tag = response_format
                # Type narrowing only: just these two models use this `type`.
                assert structural_tag is not None and isinstance(
                    structural_tag,
                    (
                        LegacyStructuralTagResponseFormat,
                        StructuralTagResponseFormat,
                    ),
                )
                s_tag_obj = structural_tag.model_dump(by_alias=True)
                self.structured_outputs.structural_tag = json.dumps(s_tag_obj)

        # Copy so that adding kv_transfer_params below does not mutate the
        # request's own `vllm_xargs` dict.
        extra_args: dict[str, Any] = dict(self.vllm_xargs) if self.vllm_xargs else {}
        if self.kv_transfer_params:
            # Pass in kv_transfer_params via extra_args
            extra_args["kv_transfer_params"] = self.kv_transfer_params
        return SamplingParams.from_optional(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            logits_processors=get_logits_processors(
                self.logits_processors, logits_processor_pattern
            ),
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            structured_outputs=self.structured_outputs,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            extra_args=extra_args or None,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )

    @model_validator(mode="before")
    @classmethod
    def check_structured_outputs_count(cls, data):
        """Reject raw payloads that set more than one of json/regex/choice."""
        if not isinstance(data, dict):
            return data

        structured_outputs_kwargs = data.get("structured_outputs")
        # Only raw dict payloads are inspected. An already-constructed
        # StructuredOutputsParams object has no `.get`, so it is passed
        # through unchanged rather than crashing with AttributeError.
        if not isinstance(structured_outputs_kwargs, dict):
            return data

        count = sum(
            structured_outputs_kwargs.get(k) is not None
            for k in ("json", "regex", "choice")
        )
        if count > 1:
            raise VLLMValidationError(
                "You can only use one kind of constraints for structured "
                "outputs ('json', 'regex' or 'choice').",
                parameter="structured_outputs",
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `logprobs` on the raw payload.

        `prompt_logprobs` must be >= 0 or exactly -1, and a positive value or
        -1 is rejected for streaming requests. `logprobs` must be >= 0.
        """
        if not isinstance(data, dict):
            return data

        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
                raise VLLMValidationError(
                    "`prompt_logprobs` are not available when `stream=True`.",
                    parameter="prompt_logprobs",
                )

            if prompt_logprobs < 0 and prompt_logprobs != -1:
                raise VLLMValidationError(
                    "`prompt_logprobs` must be a positive value or -1.",
                    parameter="prompt_logprobs",
                    value=prompt_logprobs,
                )
        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise VLLMValidationError(
                "`logprobs` must be a positive value.",
                parameter="logprobs",
                value=logprobs,
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Allow `stream_options` only together with `stream=True`."""
        if not isinstance(data, dict):
            return data

        if data.get("stream_options") and not data.get("stream"):
            raise VLLMValidationError(
                "Stream options can only be defined when `stream=True`.",
                parameter="stream_options",
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_prompt_and_prompt_embeds(cls, data):
        """Require a non-empty `prompt` or a non-empty `prompt_embeds`.

        ``None`` and empty strings, bytes or lists all count as not provided,
        so ``prompt=[]`` is rejected just like ``prompt=""``.
        """
        if not isinstance(data, dict):
            return data

        def _is_empty(value: Any) -> bool:
            return value is None or (
                isinstance(value, (str, bytes, list)) and len(value) == 0
            )

        if _is_empty(data.get("prompt")) and _is_empty(data.get("prompt_embeds")):
            raise ValueError(
                "Either prompt or prompt_embeds must be provided and non-empty."
            )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        """Require `cache_salt`, when present, to be a non-empty string."""
        if not isinstance(data, dict):
            return data

        if data.get("cache_salt") is not None and (
            not isinstance(data["cache_salt"], str) or not data["cache_salt"]
        ):
            raise ValueError(
                "Parameter 'cache_salt' must be a non-empty string if provided."
            )
        return data


# Logprob information for a completion choice. Every list defaults to a
# fresh empty list.
class CompletionLogProbs(OpenAIBaseModel):
    text_offset: list[int] = Field(default_factory=list)
    token_logprobs: list[float | None] = Field(default_factory=list)
    tokens: list[str] = Field(default_factory=list)
    top_logprobs: list[dict[str, float] | None] = Field(default_factory=list)


# A single choice in a non-streaming completion response.
class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: CompletionLogProbs | None = Field(default=None)
    finish_reason: str | None = Field(default=None)
    stop_reason: int | str | None = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    # Token ids of the generated text.
    token_ids: list[int] | None = Field(default=None)
    prompt_logprobs: list[dict[int, Logprob] | None] | None = Field(default=None)
    # Token ids of the prompt.
    prompt_token_ids: list[int] | None = Field(default=None)


# Non-streaming `text_completion` response. `id` and `created` are generated
# at construction time.
class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: Literal["text_completion"] = Field(default="text_completion")
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseChoice]
    service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = (
        Field(default=None)
    )
    system_fingerprint: str | None = Field(default=None)
    usage: UsageInfo

    # Not part of the OpenAI spec (vLLM extension).
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters.",
    )


# A single choice within one streamed completion chunk.
class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: CompletionLogProbs | None = Field(default=None)
    finish_reason: str | None = Field(default=None)
    stop_reason: int | str | None = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    # Outside the OpenAI spec; used to trace tokens. Prompt token ids live on
    # the choice to mirror CompletionResponseChoice.
    prompt_token_ids: list[int] | None = Field(default=None)
    token_ids: list[int] | None = Field(default=None)


# One chunk of a streamed completion. `usage` is optional per chunk.
class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = Field(default="text_completion")
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseStreamChoice]
    usage: UsageInfo | None = None


# A complete function invocation: name plus arguments kept as a raw string.
class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


# A complete tool call; `id` is generated by make_tool_call_id when omitted.
class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=make_tool_call_id)
    type: Literal["function"] = Field(default="function")
    function: FunctionCall


# Partial function call for streaming; both parts may be absent in a delta.
class DeltaFunctionCall(BaseModel):
    name: str | None = Field(default=None)
    arguments: str | None = Field(default=None)


# A streamed tool-call delta. Everything except `index` may be omitted from
# an individual delta.
class DeltaToolCall(OpenAIBaseModel):
    id: str | None = Field(default=None)
    type: Literal["function"] | None = Field(default=None)
    index: int
    function: DeltaFunctionCall | None = Field(default=None)


# Result of extracting tool calls from model output.
class ExtractedToolCallInformation(BaseModel):
    # Whether any tools were called.
    tools_called: bool

    # The tool calls that were extracted.
    tool_calls: list[ToolCall]

    # Remaining text content. The OpenAI spec rarely returns content alongside
    # tool calls, but some models do so intentionally.
    content: str | None = Field(default=None)


class DeltaMessage(OpenAIBaseModel):
    role: str | None = None
    content: str | None = None
    reasoning: str | None = None
    reasoning_content: str | None = None
    """Deprecated: use `reasoning` instead."""
    tool_calls: list[DeltaToolCall] = Field(default_factory=list)

    @model_validator(mode="after")
    def handle_deprecated_reasoning_content(self):
        """Keep `reasoning` and the deprecated `reasoning_content` in sync.

        `reasoning` is authoritative and is mirrored into `reasoning_content`
        for backward compatibility. If only the deprecated field was given,
        its value is promoted to `reasoning` instead of being overwritten
        with None.
        """
        if self.reasoning is None and self.reasoning_content is not None:
            self.reasoning = self.reasoning_content
        else:
            self.reasoning_content = self.reasoning
        return self


####### Tokens IN <> Tokens OUT #######
# Request that takes pre-tokenized input plus ready-made SamplingParams.
class GenerateRequest(BaseModel):
    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    token_ids: list[int]
    """The token ids to generate text from."""

    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = Field(default=None)
    """The processed MM inputs for the model."""

    sampling_params: SamplingParams
    """The sampling parameters for the model."""

    model: str | None = Field(default=None)

    stream: bool | None = Field(default=False)
    stream_options: StreamOptions | None = Field(default=None)
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )