# serving_embedding.py
# SPDX-License-Identifier: Apache-2.0

import asyncio
import base64
import time
from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast

import numpy as np
from fastapi import Request
from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                              EmbeddingRequest,
                                              EmbeddingResponse,
                                              EmbeddingResponseData,
                                              ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput)
from vllm.utils import merge_async_iterators

logger = init_logger(__name__)


def _get_embedding(
32
    output: EmbeddingOutput,
33
34
35
36
37
    encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
    if encoding_format == "float":
        return output.embedding
    elif encoding_format == "base64":
38
39
40
        # Force to use float32 for base64 encoding
        # to match the OpenAI python client behavior
        embedding_bytes = np.array(output.embedding, dtype="float32").tobytes()
41
42
43
44
45
        return base64.b64encode(embedding_bytes).decode("utf-8")

    assert_never(encoding_format)


46
47
class OpenAIServingEmbedding(OpenAIServing):
    """Handler for OpenAI-compatible embedding requests.

    Plain requests are tokenized from ``request.input``. An
    ``EmbeddingChatRequest`` is rendered through a chat template from
    ``request.messages``. Every resulting prompt is submitted via
    ``engine_client.encode``, and the pooled outputs are packed into an
    ``EmbeddingResponse``.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
    ) -> None:
        """Initialise the base server and store chat-template defaults.

        Args:
            engine_client: Engine used for tokenizers and ``encode`` calls.
            model_config: Forwarded to ``OpenAIServing``.
            models: Forwarded to ``OpenAIServing``.
            request_logger: Forwarded to ``OpenAIServing``; may be ``None``.
            chat_template: Server-wide default template, used when a chat
                request does not carry its own ``chat_template``.
            chat_template_content_format: Forwarded to ``_preprocess_chat``
                for every chat request.
        """
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger)

        # Fallback template when the request does not specify one.
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
        """
        Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.

        Args:
            request: The parsed request. If it is an ``EmbeddingChatRequest``,
                ``request.messages`` are rendered through the chat template.
                Otherwise ``request.input`` is tokenized directly.
            raw_request: The underlying HTTP request, if any. Used to derive
                the request id and the tracing headers.

        Returns:
            An ``EmbeddingResponse`` on success. Otherwise an
            ``ErrorResponse`` for a failed model check, an unsupported
            ``dimensions`` value, an oversized ``truncate_prompt_tokens``,
            a ``ValueError`` during preprocessing/scheduling/collection, or
            a client disconnect.

        Raises:
            NotImplementedError: If a prompt adapter is requested. This is
                not a ``ValueError``, so it is not converted into an
                ``ErrorResponse``.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        encoding_format = request.encoding_format
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"embd-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        truncate_prompt_tokens = None

        if request.truncate_prompt_tokens is not None:
            if request.truncate_prompt_tokens <= self.max_model_len:
                truncate_prompt_tokens = request.truncate_prompt_tokens
            else:
                return self.create_error_response(
                    "truncate_prompt_tokens value is "
                    "greater than max_model_len."
                    " Please, select a smaller truncation size.")

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for embedding models")

            if isinstance(request, EmbeddingChatRequest):
                (
                    _,
                    request_prompts,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.
                    chat_template_content_format,
                    # In embedding requests, we are not generating tokens,
                    # so there is no need to append extra tokens to the input
                    add_generation_prompt=False,
                    continue_final_message=False,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                (request_prompts,
                 engine_prompts) = await self._preprocess_completion(
                     request,
                     tokenizer,
                     request.input,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 )
        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
        try:
            pooling_params = request.to_pooling_params()

            # The tracing headers are derived from the HTTP request alone,
            # not from any individual prompt, so fetch them once per batch
            # instead of awaiting the same lookup for every prompt.
            trace_headers = (None if raw_request is None else await
                             self._get_trace_headers(raw_request.headers))

            for i, engine_prompt in enumerate(engine_prompts):
                # Each prompt gets its own engine request id.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: List[Optional[PoolingRequestOutput]]
        final_res_batch = [None] * num_prompts
        try:
            # ``i`` is used as the prompt's position. A later output for the
            # same position overwrites an earlier one, so the last one wins.
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(List[PoolingRequestOutput],
                                           final_res_batch)

            response = self.request_output_to_embedding_response(
                final_res_batch_checked,
                request_id,
                created_time,
                model_name,
                encoding_format,
            )
        except asyncio.CancelledError:
            # Deliberately converted into an error response rather than
            # re-raised.
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def request_output_to_embedding_response(
        self,
        final_res_batch: List[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["float", "base64"],
    ) -> EmbeddingResponse:
        """Assemble an ``EmbeddingResponse`` from per-prompt pooling outputs.

        Args:
            final_res_batch: One output per prompt. Each output's position in
                this list becomes the ``index`` of its response item.
            request_id: Value for the response ``id``.
            created_time: Value for the response ``created`` field.
            model_name: Value for the response ``model`` field.
            encoding_format: ``"float"`` or ``"base64"``. Passed through to
                ``_get_embedding`` for every item.

        Returns:
            The response. ``usage.prompt_tokens`` and ``usage.total_tokens``
            are both the summed lengths of every output's
            ``prompt_token_ids``.
        """
        items: List[EmbeddingResponseData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(final_res_batch):
            embedding_res = EmbeddingRequestOutput.from_base(final_res)

            items.append(
                EmbeddingResponseData(
                    index=idx,
                    embedding=_get_embedding(embedding_res.outputs,
                                             encoding_format),
                ))
            num_prompt_tokens += len(final_res.prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return EmbeddingResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )