serving_embedding.py 6.67 KB
Newer Older
1
import asyncio
2
import base64
3
import time
4
5
from typing import (AsyncGenerator, AsyncIterator, List, Optional, Tuple,
                    Union, cast)
6

7
import numpy as np
8
9
10
from fastapi import Request

from vllm.config import ModelConfig
11
from vllm.engine.protocol import AsyncEngineClient
12
from vllm.entrypoints.logger import RequestLogger
13
14
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
                                              EmbeddingResponse,
15
16
                                              EmbeddingResponseData,
                                              ErrorResponse, UsageInfo)
17
18
19
20
21
22
23
24
25
26
27
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid

# Module-level logger used by the embedding handler below.
logger = init_logger(__name__)

# Type alias for a list of token IDs. NOTE(review): not referenced in the
# code visible here; confirm whether other modules import it before removing.
TypeTokenIDs = List[int]


def request_output_to_embedding_response(
        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
        created_time: int, model_name: str,
        encoding_format: str) -> EmbeddingResponse:
    """Convert finished embedding outputs into an OpenAI-style response.

    Args:
        final_res_batch: One finished output per input prompt, in prompt
            order. Each output's position in this list becomes the
            ``index`` of the corresponding response item.
        request_id: Value reported in the response's ``id`` field.
        created_time: Value reported in the response's ``created`` field.
        model_name: Model name echoed back in the response.
        encoding_format: ``"base64"`` to return each embedding as a base64
            string of packed little-endian float32 values; any other value
            returns ``outputs.embedding`` unchanged.

    Returns:
        An ``EmbeddingResponse`` whose usage counts every prompt token of
        every output in ``final_res_batch``.
    """
    data: List[EmbeddingResponseData] = []
    num_prompt_tokens = 0
    for idx, final_res in enumerate(final_res_batch):
        embedding = final_res.outputs.embedding
        if encoding_format == "base64":
            # The OpenAI wire format (and the official client's decoder)
            # expects packed float32. Without an explicit dtype numpy packs
            # float64, which doubles the payload and decodes to garbage on
            # the client. "<f4" pins both width and byte order.
            embedding_bytes = np.asarray(embedding, dtype="<f4").tobytes()
            embedding = base64.b64encode(embedding_bytes).decode("utf-8")
        data.append(EmbeddingResponseData(index=idx, embedding=embedding))
        num_prompt_tokens += len(final_res.prompt_token_ids)

    # No completion tokens are produced for embeddings, so the total equals
    # the prompt-token count.
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        total_tokens=num_prompt_tokens,
    )

    return EmbeddingResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=data,
        usage=usage,
    )


class OpenAIServingEmbedding(OpenAIServing):
    """Handler for OpenAI-compatible embedding requests.

    Requests are rejected with an error response unless the model config
    has ``embedding_mode`` enabled.
    """

    def __init__(
        self,
        async_engine_client: AsyncEngineClient,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        request_logger: Optional[RequestLogger],
    ):
        # LoRA modules and prompt adapters are not configured for this
        # handler; both are passed to the base class as None.
        super().__init__(async_engine_client=async_engine_client,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=None,
                         prompt_adapters=None,
                         request_logger=request_logger)
        self._enabled = self._check_embedding_mode(model_config.embedding_mode)

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Optional[Request] = None
    ) -> Union[ErrorResponse, EmbeddingResponse]:
        """Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.

        Each input prompt is submitted to the engine as its own sub-request
        (``<request_id>-<i>``). The results are gathered back into prompt
        order and returned as a single ``EmbeddingResponse``.

        Returns:
            An ``ErrorResponse`` in any of these cases:
            - the embedding API is disabled;
            - the model check fails;
            - ``dimensions`` is set;
            - a ``ValueError`` is raised during scheduling or aggregation;
            - the client disconnects.
            Otherwise returns the ``EmbeddingResponse``.
        """
        if not self._enabled:
            return self.create_error_response("Embedding API disabled")
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Any falsy value (None / empty string) falls back to "float".
        encoding_format = request.encoding_format or "float"
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"embd-{random_uuid()}"
        # `created` must be a Unix timestamp. time.monotonic() has an
        # arbitrary epoch and is not wall-clock time.
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.async_engine_client.get_tokenizer(
                lora_request)

            pooling_params = request.to_pooling_params()

            prompts = list(
                self._tokenize_prompt_input_or_inputs(
                    request,
                    tokenizer,
                    request.input,
                ))

            for i, prompt_inputs in enumerate(prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 prompt_inputs,
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                if prompt_adapter_request is not None:
                    raise NotImplementedError(
                        "Prompt adapter is not supported "
                        "for embedding models")

                generator = self.async_engine_client.encode(
                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # Results arrive as (prompt index, output) pairs in completion
        # order. Iteration is cancelled if the HTTP client disconnects.
        result_generator: AsyncIterator[Tuple[
            int, EmbeddingRequestOutput]] = merge_async_iterators(
                *generators,
                is_cancelled=raw_request.is_disconnected
                if raw_request else None)

        # Non-streaming response
        final_res_batch: List[Optional[EmbeddingRequestOutput]]
        final_res_batch = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                # Store each result at its prompt index to restore
                # prompt order.
                final_res_batch[i] = res

            for final_res in final_res_batch:
                assert final_res is not None

            final_res_batch_checked = cast(List[EmbeddingRequestOutput],
                                           final_res_batch)

            response = request_output_to_embedding_response(
                final_res_batch_checked, request_id, created_time, model_name,
                encoding_format)
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def _check_embedding_mode(self, embedding_mode: bool):
        """Log whether the embedding API is active and return the flag.

        Args:
            embedding_mode: The model config's ``embedding_mode`` setting.

        Returns:
            ``embedding_mode`` unchanged. It is stored as ``self._enabled``.
        """
        if not embedding_mode:
            logger.warning(
                "embedding_mode is False. Embedding API will not work.")
        else:
            logger.info("Activating the server engine with embedding enabled.")
        return embedding_mode