serving_embedding.py 6.39 KB
Newer Older
1
import base64
2
import time
3
from typing import AsyncIterator, List, Optional, Tuple, cast
4

5
import numpy as np
6
7
8
from fastapi import Request

from vllm.config import ModelConfig
9
from vllm.engine.protocol import AsyncEngineClient
10
from vllm.entrypoints.logger import RequestLogger
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
                                              EmbeddingResponse,
                                              EmbeddingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid

# Module-level logger; used below for the embedding-mode startup messages.
logger = init_logger(__name__)

# Alias for a list of token ids.
# NOTE(review): not referenced in the visible part of this file -- confirm
# whether it is still needed.
TypeTokenIDs = List[int]


def request_output_to_embedding_response(
        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
        created_time: int, model_name: str,
        encoding_format: str) -> EmbeddingResponse:
    """Build an ``EmbeddingResponse`` from a batch of finished outputs.

    Args:
        final_res_batch: One engine output per input prompt. The position of
            an output in this list becomes the ``index`` of its data item.
        request_id: Value stored in the response's ``id`` field.
        created_time: Value stored in the response's ``created`` field.
        model_name: Value stored in the response's ``model`` field.
        encoding_format: When equal to ``"base64"``, each embedding is
            returned as a base64 string of its raw float32 bytes. Any other
            value returns the embedding as produced by the engine.

    Returns:
        The assembled response. Usage reports prompt tokens only, so
        ``total_tokens`` equals ``prompt_tokens``.
    """
    data: List[EmbeddingResponseData] = []
    num_prompt_tokens = 0

    for idx, final_res in enumerate(final_res_batch):
        embedding = final_res.outputs.embedding
        if encoding_format == "base64":
            # Force float32 before serialising. Without an explicit dtype
            # NumPy turns a list of Python floats into float64, doubling the
            # payload and producing bytes that clients decoding the buffer
            # as float32 (as the OpenAI Python client does) read as garbage.
            embedding_bytes = np.array(embedding, dtype="float32").tobytes()
            embedding = base64.b64encode(embedding_bytes).decode("utf-8")

        data.append(EmbeddingResponseData(index=idx, embedding=embedding))
        num_prompt_tokens += len(final_res.prompt_token_ids)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        total_tokens=num_prompt_tokens,
    )

    return EmbeddingResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=data,
        usage=usage,
    )


class OpenAIServingEmbedding(OpenAIServing):
    """Request handler for the OpenAI-compatible embeddings endpoint."""

    def __init__(
        self,
        async_engine_client: AsyncEngineClient,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        request_logger: Optional[RequestLogger],
    ):
        """Initialise the handler.

        LoRA modules and prompt adapters are always passed to the base class
        as ``None``. A log message reporting whether
        ``model_config.embedding_mode`` is enabled is emitted at construction.

        Args:
            async_engine_client: Engine client used to fetch the tokenizer,
                run ``encode`` and abort requests.
            model_config: Model configuration; its ``embedding_mode`` flag is
                inspected here.
            served_model_names: Forwarded unchanged to the base class.
            request_logger: Forwarded unchanged to the base class.
        """
        super().__init__(async_engine_client=async_engine_client,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=None,
                         prompt_adapters=None,
                         request_logger=request_logger)
        self._check_embedding_mode(model_config.embedding_mode)

    async def create_embedding(self, request: EmbeddingRequest,
                               raw_request: Request):
        """Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.

        Args:
            request: The parsed embedding request.
            raw_request: The underlying HTTP request, polled to detect a
                client disconnect while results are being collected.

        Returns:
            An ``EmbeddingResponse`` on success. Otherwise an error response:
            either the one produced by ``_check_model`` or one built by
            ``create_error_response`` (unsupported ``dimensions``, a
            ``ValueError`` raised while preparing or collecting results, or a
            client disconnect).

        Raises:
            NotImplementedError: If a prompt adapter is resolved for the
                request and there is at least one prompt to encode.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Fall back to "float" when the request leaves the format unset.
        encoding_format = (request.encoding_format
                           if request.encoding_format else "float")
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"embd-{random_uuid()}"
        # `created` must be wall-clock Unix seconds. time.monotonic() has an
        # undefined reference point and is only meaningful for measuring
        # intervals, so it must not be used for a timestamp.
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.async_engine_client.get_tokenizer(
                lora_request)

            pooling_params = request.to_pooling_params()

            prompts = list(
                self._tokenize_prompt_input_or_inputs(
                    request,
                    tokenizer,
                    request.input,
                ))

            for i, prompt_inputs in enumerate(prompts):
                # Each prompt becomes its own engine request, suffixed with
                # its position so results can be matched back up later.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 prompt_inputs,
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                # NOTE(review): this is not a ValueError, so it is not turned
                # into an error response by the handler below and propagates
                # to the caller.
                if prompt_adapter_request is not None:
                    raise NotImplementedError(
                        "Prompt adapter is not supported "
                        "for embedding models")

                generator = self.async_engine_client.encode(
                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[
            int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)

        # Non-streaming response
        final_res_batch: List[Optional[EmbeddingRequestOutput]]
        final_res_batch = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                if await raw_request.is_disconnected():
                    # Abort the request if the client disconnects.
                    # NOTE(review): only the sub-request that produced this
                    # result is aborted -- confirm whether the remaining
                    # sub-requests should be aborted as well.
                    await self.async_engine_client.abort(f"{request_id}-{i}")
                    return self.create_error_response("Client disconnected")
                final_res_batch[i] = res

            # Every slot must have been filled once the merged stream ends.
            for final_res in final_res_batch:
                assert final_res is not None

            final_res_batch_checked = cast(List[EmbeddingRequestOutput],
                                           final_res_batch)

            response = request_output_to_embedding_response(
                final_res_batch_checked, request_id, created_time, model_name,
                encoding_format)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def _check_embedding_mode(self, embedding_mode: bool):
        """Log whether embedding mode is enabled; never raises.

        Args:
            embedding_mode: The model configuration's embedding-mode flag.
        """
        if not embedding_mode:
            logger.warning(
                "embedding_mode is False. Embedding API will not work.")
        else:
            logger.info("Activating the server engine with embedding enabled.")