protocol.py 3.91 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
4
from typing import Any, TypeAlias
5
6
7
8
9
10

from pydantic import (
    Field,
)

from vllm import PoolingParams
11
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
12
from vllm.entrypoints.pooling.base.protocol import (
13
    ChatRequestMixin,
14
15
16
    CompletionRequestMixin,
    PoolingBasicRequestMixin,
)
17
18
19
20
from vllm.utils import random_uuid
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness


21
class EmbeddingCompletionRequest(PoolingBasicRequestMixin, CompletionRequestMixin):
22
23
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
24

25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
    encoding_format: EncodingFormat = "float"
    dimensions: int | None = None

    # --8<-- [start:embedding-extra-params]
    normalize: bool | None = Field(
        default=None,
        description="Whether to normalize the embeddings outputs. Default is True.",
    )
    embed_dtype: EmbedDType = Field(
        default="float32",
        description=(
            "What dtype to use for encoding. Default to using float32 for base64 "
            "encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    endianness: Endianness = Field(
        default="native",
        description=(
            "What endianness to use for encoding. Default to using native for "
            "base64 encoding to match the OpenAI python client behavior."
            "This parameter will affect base64 and binary_response."
        ),
    )
    # --8<-- [end:embedding-extra-params]

    def to_pooling_params(self):
        return PoolingParams(
            dimensions=self.dimensions,
54
            use_activation=self.normalize,
55
            truncate_prompt_tokens=self.truncate_prompt_tokens,
56
57
58
        )


59
class EmbeddingChatRequest(PoolingBasicRequestMixin, ChatRequestMixin):
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    encoding_format: EncodingFormat = "float"
    dimensions: int | None = None

    # --8<-- [start:chat-embedding-extra-params]
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    normalize: bool | None = Field(
        default=None,
        description="Whether to normalize the embeddings outputs. Default is True.",
    )
    embed_dtype: EmbedDType = Field(
        default="float32",
        description=(
            "What dtype to use for encoding. Default to using float32 for base64 "
            "encoding to match the OpenAI python client behavior. "
            "This parameter will affect base64 and binary_response."
        ),
    )
    endianness: Endianness = Field(
        default="native",
        description=(
            "What endianness to use for encoding. Default to using native for "
            "base64 encoding to match the OpenAI python client behavior."
            "This parameter will affect base64 and binary_response."
        ),
    )
    # --8<-- [end:chat-embedding-extra-params]

    def to_pooling_params(self):
        return PoolingParams(
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            dimensions=self.dimensions,
94
            use_activation=self.normalize,
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
        )


EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest


class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    embedding: list[float] | str


class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[EmbeddingResponseData]
    usage: UsageInfo


class EmbeddingBytesResponse(OpenAIBaseModel):
117
118
    content: list[bytes]
    headers: dict[str, str] | None = None
119
    media_type: str = "application/octet-stream"