serving.py 5.59 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import json
4
from collections.abc import Callable
5
from functools import partial
6
from typing import Literal, TypeAlias, cast
7

8
from fastapi.responses import JSONResponse, StreamingResponse
9
from typing_extensions import assert_never
10

11
12
13
14
15
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
16
17
18
19
20
21
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingBytesResponse,
    EmbeddingRequest,
    EmbeddingResponse,
    EmbeddingResponseData,
)
22
from vllm.entrypoints.pooling.typing import PoolingServeContext
23
24
25
26
from vllm.entrypoints.pooling.utils import (
    encode_pooling_bytes,
    encode_pooling_output_base64,
    encode_pooling_output_float,
27
    get_json_response_cls,
28
)
29
30
from vllm.outputs import PoolingRequestOutput
from vllm.renderers import BaseRenderer
31
from vllm.utils.serial_utils import EmbedDType, Endianness
32

33
# Response class used to serialize the "float"/"base64" embedding responses;
# the concrete class is whatever get_json_response_cls() selects.
JSONResponseCLS = get_json_response_cls()

# Pooling serve context specialised to embedding requests.
EmbeddingServeContext: TypeAlias = PoolingServeContext[EmbeddingRequest]


class ServingEmbedding(PoolingServing):
    """
    Embedding API similar to OpenAI's API.

    See https://platform.openai.com/docs/api-reference/embeddings/create
    for the API specification. This API mimics the OpenAI Embedding API.

    The response shape depends on the request's ``encoding_format``:

    * ``"float"`` / ``"base64"`` -- a JSON ``EmbeddingResponse``.
    * ``"bytes"`` -- encoded embeddings streamed as the body, with the
      response metadata JSON-encoded into a ``metadata`` header.
    * ``"bytes_only"`` -- encoded embeddings streamed as the body, with no
      metadata header.
    """

    # Prefix for request ids. NOTE(review): presumably consumed by the
    # PoolingServing base class when generating request ids -- confirm there.
    request_id_prefix = "embd"

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ) -> EmbedIOProcessor:
        """Create the embedding-specific IO processor.

        Args:
            model_config: Configuration of the served model.
            renderer: Renderer forwarded to the processor.
            chat_template_config: Chat template settings forwarded to the
                processor.

        Returns:
            A new ``EmbedIOProcessor`` built from the given arguments.
        """
        return EmbedIOProcessor(
            model_config=model_config,
            renderer=renderer,
            chat_template_config=chat_template_config,
        )

    async def _build_response(
        self,
        ctx: EmbeddingServeContext,
    ) -> JSONResponse | StreamingResponse:
        """Turn the finished pooling outputs in *ctx* into an HTTP response.

        The response kind is chosen by ``ctx.request.encoding_format``; the
        request's ``embed_dtype`` and ``endianness`` are forwarded to the
        encoder.

        Args:
            ctx: Serve context carrying the request, request id, creation
                time, model name and the batch of final pooling outputs.

        Returns:
            A JSON response for ``"float"``/``"base64"``, or a
            ``StreamingResponse`` for ``"bytes"``/``"bytes_only"``.
        """
        encoding_format = ctx.request.encoding_format
        embed_dtype = ctx.request.embed_dtype
        endianness = ctx.request.endianness

        if encoding_format == "float" or encoding_format == "base64":
            return self._request_output_to_embed_json_response(
                ctx.final_res_batch,
                ctx.request_id,
                ctx.created_time,
                ctx.model_name,
                encoding_format,
                embed_dtype,
                endianness,
            )

        if encoding_format == "bytes" or encoding_format == "bytes_only":
            return self._request_output_to_to_embed_bytes_response(
                ctx.final_res_batch,
                ctx.request_id,
                ctx.created_time,
                ctx.model_name,
                encoding_format,
                embed_dtype,
                endianness,
            )

        # Exhaustiveness check: every encoding format is handled above, so a
        # type checker reports any newly added format that is not.
        assert_never(encoding_format)

    def _request_output_to_embed_json_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["float", "base64"],
        embed_dtype: EmbedDType,
        endianness: Endianness,
    ) -> JSONResponse:
        """Build a JSON ``EmbeddingResponse`` from pooling outputs.

        Args:
            final_res_batch: Pooling outputs; each item's ``index`` is its
                position in this list.
            request_id: Value for the response ``id`` field.
            created_time: Value for the response ``created`` field.
            model_name: Value for the response ``model`` field.
            encoding_format: ``"float"`` encodes each output with
                ``encode_pooling_output_float``; ``"base64"`` uses
                ``encode_pooling_output_base64`` with *embed_dtype* and
                *endianness*.
            embed_dtype: Element dtype for base64 encoding (unused for float).
            endianness: Byte order for base64 encoding (unused for float).

        Returns:
            ``JSONResponseCLS`` wrapping the dumped ``EmbeddingResponse``.
        """
        encode_fn = cast(
            Callable[[PoolingRequestOutput], list[float] | str],
            (
                encode_pooling_output_float
                if encoding_format == "float"
                else partial(
                    encode_pooling_output_base64,
                    embed_dtype=embed_dtype,
                    endianness=endianness,
                )
            ),
        )

        items = [
            EmbeddingResponseData(index=idx, embedding=encode_fn(final_res))
            for idx, final_res in enumerate(final_res_batch)
        ]

        # Only prompt tokens are counted; total_tokens mirrors prompt_tokens.
        num_prompt_tokens = sum(
            len(final_res.prompt_token_ids) for final_res in final_res_batch
        )
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        response = EmbeddingResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )
        return JSONResponseCLS(content=response.model_dump())

    # NOTE(review): the method name contains a duplicated "to"; it is kept
    # unchanged in case code outside this class calls or overrides it.
    def _request_output_to_to_embed_bytes_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["bytes", "bytes_only"],
        embed_dtype: EmbedDType,
        endianness: Endianness,
    ) -> StreamingResponse:
        """Build a streaming binary response from pooling outputs.

        ``encode_pooling_bytes`` produces the body content plus per-item
        metadata and usage. For ``"bytes"`` that metadata, together with
        *request_id*, *created_time* and *model_name*, is JSON-encoded into a
        ``metadata`` response header; ``"bytes_only"`` sends no such header.

        Args:
            final_res_batch: Pooling outputs to encode.
            request_id: Value for the metadata ``id`` field.
            created_time: Value for the metadata ``created`` field.
            model_name: Value for the metadata ``model`` field.
            encoding_format: ``"bytes"`` or ``"bytes_only"`` (see above).
            embed_dtype: Element dtype passed to ``encode_pooling_bytes``.
            endianness: Byte order passed to ``encode_pooling_bytes``.

        Returns:
            A ``StreamingResponse`` using the content, headers and media type
            of an ``EmbeddingBytesResponse``.
        """
        content, items, usage = encode_pooling_bytes(
            pooling_outputs=final_res_batch,
            embed_dtype=embed_dtype,
            endianness=endianness,
        )

        headers = (
            None
            if encoding_format == "bytes_only"
            else {
                "metadata": json.dumps(
                    {
                        "id": request_id,
                        "created": created_time,
                        "model": model_name,
                        "data": items,
                        "usage": usage,
                    }
                )
            }
        )

        response = EmbeddingBytesResponse(content=content, headers=headers)
        return StreamingResponse(
            content=response.content,
            headers=response.headers,
            media_type=response.media_type,
        )