protocol.py 3.64 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
4
from typing import TypeAlias
5

6
from pydantic import Field
7

8
from vllm import PoolingParams
9
from vllm.config import ModelConfig
10
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
11
from vllm.entrypoints.pooling.base.protocol import (
12
    ChatRequestMixin,
13
    CompletionRequestMixin,
14
    EmbedRequestMixin,
15
16
    PoolingBasicRequestMixin,
)
17
from vllm.renderers import TokenizeParams
18
19
20
from vllm.utils import random_uuid


21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def _get_max_total_output_tokens(
    model_config: ModelConfig,
) -> tuple[int | None, int]:
    """Derive the token limits handed to the tokenizer for embedding requests.

    Returns a ``(max_total_tokens, max_output_tokens)`` pair:

    * no pooler config: ``(max_model_len, 0)``
    * pooler config with chunked processing enabled: ``(None, 0)``
    * otherwise: ``(max_model_len, max_model_len - max_embed_len)``, where a
      falsy ``max_embed_len`` (``None`` or ``0``) is replaced by
      ``max_model_len`` so the second value becomes ``0``.

    The second value is negative whenever ``max_embed_len`` exceeds
    ``max_model_len``; no clamping is applied here.
    """
    model_len = model_config.max_model_len
    pooler = model_config.pooler_config

    # Guard clauses for the two cases that reserve no output tokens.
    if pooler is None:
        return model_len, 0
    if pooler.enable_chunked_processing:
        return None, 0

    embed_len = pooler.max_embed_len
    if not embed_len:
        # Unset embed length falls back to the model length.
        embed_len = model_len

    return model_len, model_len - embed_len


38
39
40
class EmbeddingCompletionRequest(
    PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin
):
    # Embedding request assembled from three mixins. No fields are declared
    # here; the attributes read below (truncate_prompt_tokens,
    # add_special_tokens, dimensions, use_activation) are expected to come
    # from the mixins, which are defined outside this file.

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Build the tokenization parameters for this request.

        Args:
            model_config: Model configuration; supplies ``encoder_config``
                and, via ``_get_max_total_output_tokens``, the token limits.

        Returns:
            A ``TokenizeParams`` combining the model-derived limits with
            this request's ``truncate_prompt_tokens`` and
            ``add_special_tokens``.
        """
        # encoder_config may be unset; use an empty mapping so .get() works.
        encoder_config = model_config.encoder_config or {}

        max_total_tokens, max_output_tokens = _get_max_total_output_tokens(
            model_config
        )

        return TokenizeParams(
            max_total_tokens=max_total_tokens,
            max_output_tokens=max_output_tokens,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            # NOTE(review): presumably human-readable labels for the two
            # limits (e.g. for error messages) - confirm in TokenizeParams.
            max_total_tokens_param="max_model_len",
            max_output_tokens_param="max_model_len - max_embed_len",
        )

    def to_pooling_params(self) -> PoolingParams:
        """Return ``PoolingParams`` for the ``"embed"`` task.

        ``dimensions`` and ``use_activation`` are forwarded unchanged from
        this request.
        """
        return PoolingParams(
            task="embed",
            dimensions=self.dimensions,
            use_activation=self.use_activation,
        )

66

67
68
69
class EmbeddingChatRequest(
    PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin
):
    # Embedding request assembled from three mixins, using ChatRequestMixin
    # where EmbeddingCompletionRequest uses CompletionRequestMixin. No fields
    # are declared here; the attributes read below (truncate_prompt_tokens,
    # add_special_tokens, dimensions, use_activation) are expected to come
    # from the mixins, which are defined outside this file.

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Build the tokenization parameters for this request.

        Args:
            model_config: Model configuration; supplies ``encoder_config``
                and, via ``_get_max_total_output_tokens``, the token limits.

        Returns:
            A ``TokenizeParams`` combining the model-derived limits with
            this request's ``truncate_prompt_tokens`` and
            ``add_special_tokens``.
        """
        # encoder_config may be unset; use an empty mapping so .get() works.
        encoder_config = model_config.encoder_config or {}

        max_total_tokens, max_output_tokens = _get_max_total_output_tokens(
            model_config
        )

        return TokenizeParams(
            max_total_tokens=max_total_tokens,
            max_output_tokens=max_output_tokens,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            # NOTE(review): presumably human-readable labels for the two
            # limits (e.g. for error messages) - confirm in TokenizeParams.
            max_total_tokens_param="max_model_len",
            max_output_tokens_param="max_model_len - max_embed_len",
        )

    def to_pooling_params(self) -> PoolingParams:
        """Return ``PoolingParams`` for the ``"embed"`` task.

        ``dimensions`` and ``use_activation`` are forwarded unchanged from
        this request.
        """
        return PoolingParams(
            task="embed",
            dimensions=self.dimensions,
            use_activation=self.use_activation,
        )

95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

# Union alias covering both embedding request models defined in this module.
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest


class EmbeddingResponseData(OpenAIBaseModel):
    # One entry of EmbeddingResponse.data.
    # NOTE(review): index presumably identifies the matching input's
    # position in the request - confirm against the code that builds it.
    index: int
    # Object-type tag; defaults to "embedding".
    object: str = "embedding"
    # Either a list of floats or a single string.
    # NOTE(review): the string form is presumably an encoded (e.g. base64)
    # representation of the vector - confirm against the serving code.
    embedding: list[float] | str


class EmbeddingResponse(OpenAIBaseModel):
    # Response envelope holding a list of EmbeddingResponseData plus usage.
    # Defaults to "embd-" followed by a value from random_uuid().
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    # Object-type tag; defaults to "list".
    object: str = "list"
    # Defaults to time.time() truncated to whole seconds (Unix timestamp).
    created: int = Field(default_factory=lambda: int(time.time()))
    # Required; no default.
    model: str
    data: list[EmbeddingResponseData]
    usage: UsageInfo


class EmbeddingBytesResponse(OpenAIBaseModel):
    # Response model carrying byte payloads rather than EmbeddingResponseData.
    # NOTE(review): presumably used to send embeddings as a binary HTTP
    # body - confirm against the serving code.
    # Required list of byte chunks.
    content: list[bytes]
    # Optional string-to-string header mapping; defaults to None.
    headers: dict[str, str] | None = None
    # Media type for the payload; defaults to a generic binary type.
    media_type: str = "application/octet-stream"