serving_embedding.py 8.32 KB
Newer Older
1
import asyncio
2
import base64
3
import time
4
from typing import AsyncGenerator, List, Literal, Optional, Union, cast
5

6
import numpy as np
7
from fastapi import Request
8
from typing_extensions import assert_never
9
10

from vllm.config import ModelConfig
11
from vllm.engine.protocol import EngineClient
12
from vllm.entrypoints.chat_utils import load_chat_template
13
from vllm.entrypoints.logger import RequestLogger
14
15
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                              EmbeddingRequest,
16
                                              EmbeddingResponse,
17
18
                                              EmbeddingResponseData,
                                              ErrorResponse, UsageInfo)
19
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
20
from vllm.logger import init_logger
21
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
22
23
24
25
26
from vllm.utils import merge_async_iterators, random_uuid

# Module-level logger obtained from vLLM's init_logger, keyed by this
# module's __name__.
logger = init_logger(__name__)


def _get_embedding(
    output: EmbeddingOutput,
    encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
    """Return ``output.embedding`` in the requested wire format.

    Args:
        output: Engine output whose ``embedding`` attribute holds the vector.
        encoding_format: ``"float"`` returns the embedding object unchanged;
            ``"base64"`` packs the values as float32 bytes (NumPy's native
            byte order) and base64-encodes them into a UTF-8 string.

    Returns:
        The embedding as-is for ``"float"``, or a base64 string for
        ``"base64"``.
    """
    if encoding_format == "float":
        return output.embedding
    if encoding_format == "base64":
        # Force to use float32 for base64 encoding
        # to match the OpenAI python client behavior
        embedding_bytes = np.array(output.embedding, dtype="float32").tobytes()
        return base64.b64encode(embedding_bytes).decode("utf-8")

    # Static exhaustiveness check: unreachable for valid Literal values.
    assert_never(encoding_format)


def request_output_to_embedding_response(
        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
        created_time: int, model_name: str,
        encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
    """Assemble an ``EmbeddingResponse`` from finished engine outputs.

    Args:
        final_res_batch: One finished output per prompt. Each item's list
            position becomes the ``index`` of its response entry.
        request_id: ID reported in the response.
        created_time: Timestamp reported in the response's ``created`` field.
        model_name: Model name reported in the response.
        encoding_format: Passed to ``_get_embedding`` to select float-list or
            base64 output for every entry.

    Returns:
        The response. Usage counts the prompt tokens of every output, and
        ``total_tokens`` is set equal to ``prompt_tokens``.
    """
    data: List[EmbeddingResponseData] = []
    num_prompt_tokens = 0
    for idx, final_res in enumerate(final_res_batch):
        embedding = _get_embedding(final_res.outputs, encoding_format)
        data.append(EmbeddingResponseData(index=idx, embedding=embedding))
        num_prompt_tokens += len(final_res.prompt_token_ids)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        total_tokens=num_prompt_tokens,
    )

    return EmbeddingResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=data,
        usage=usage,
    )


class OpenAIServingEmbedding(OpenAIServing):
    """Serve OpenAI-compatible embedding requests.

    Chat-style requests (``EmbeddingChatRequest``) are rendered through a chat
    template. Other requests are tokenized from ``request.input``. Each
    resulting prompt is submitted to the engine with its own ``encode`` call,
    and the finished outputs are combined into one ``EmbeddingResponse``.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: List[BaseModelPath],
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
    ):
        """Initialize the embedding server.

        Args:
            engine_client: Forwarded to the base class; also used here to
                fetch tokenizers and submit ``encode`` calls.
            model_config: Forwarded to the base class unchanged.
            base_model_paths: Forwarded to the base class unchanged.
            request_logger: Forwarded to the base class unchanged.
            chat_template: Default chat template, loaded via
                ``load_chat_template``. Used for chat requests that do not
                provide their own ``chat_template``.
        """
        # No LoRA modules or prompt adapters are registered for embeddings.
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         base_model_paths=base_model_paths,
                         lora_modules=None,
                         prompt_adapters=None,
                         request_logger=request_logger)

        self.chat_template = load_chat_template(chat_template)

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
        """
        Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.

        Args:
            request: The parsed request. ``EmbeddingChatRequest`` instances
                go through ``_preprocess_chat``; all others go through
                ``_preprocess_completion`` on ``request.input``.
            raw_request: The underlying HTTP request, if any. Its headers
                supply trace headers, and its ``is_disconnected`` is passed
                to ``merge_async_iterators`` as the cancellation check.

        Returns:
            An ``EmbeddingResponse`` on success. Otherwise an
            ``ErrorResponse`` in any of these cases:

            - the model check fails;
            - ``dimensions`` is set;
            - ``truncate_prompt_tokens`` exceeds ``max_model_len``;
            - a ``ValueError`` is raised during preprocessing, scheduling or
              response assembly;
            - result collection is cancelled.

        Raises:
            NotImplementedError: If the request resolves to a prompt adapter.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        encoding_format = request.encoding_format
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"embd-{random_uuid()}"
        # ``created`` is a Unix timestamp (seconds since the epoch).
        # time.monotonic() has an undefined reference point and must not be
        # used for wall-clock timestamps.
        created_time = int(time.time())

        truncate_prompt_tokens = None
        if request.truncate_prompt_tokens is not None:
            if request.truncate_prompt_tokens <= self.max_model_len:
                truncate_prompt_tokens = request.truncate_prompt_tokens
            else:
                return self.create_error_response(
                    "truncate_prompt_tokens value is "
                    "greater than max_model_len."
                    " Please, select a smaller truncation size.")

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                # NotImplementedError is not a ValueError, so it propagates
                # past the handler below rather than becoming an error
                # response.
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for embedding models")

            if isinstance(request, EmbeddingChatRequest):
                (
                    _,
                    request_prompts,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                request_prompts, engine_prompts = self._preprocess_completion(
                    request,
                    tokenizer,
                    request.input,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
        try:
            pooling_params = request.to_pooling_params()

            # Trace headers depend only on the HTTP request, so fetch them
            # once rather than once per prompt.
            trace_headers = (None if raw_request is None else await
                             self._get_trace_headers(raw_request.headers))

            for i, engine_prompt in enumerate(engine_prompts):
                # Each prompt is a separate engine request whose ID is
                # derived from the response ID.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(
            *generators,
            is_cancelled=raw_request.is_disconnected if raw_request else None,
        )

        num_prompts = len(engine_prompts)

        # Non-streaming response: store each output at the prompt index that
        # merge_async_iterators yields with it, keeping the last one received
        # for each index.
        final_res_batch: List[Optional[EmbeddingRequestOutput]]
        final_res_batch = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

        try:
            # Internal invariant: every prompt produced an output.
            for final_res in final_res_batch:
                assert final_res is not None

            final_res_batch_checked = cast(List[EmbeddingRequestOutput],
                                           final_res_batch)

            response = request_output_to_embedding_response(
                final_res_batch_checked, request_id, created_time, model_name,
                encoding_format)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response