# serving_pooling.py (10.7 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import asyncio
import base64
import time
7
from collections.abc import AsyncGenerator
8
from typing import Final, Literal, Optional, Union, cast
9

10
import jinja2
11
import numpy as np
12
import torch
13
14
15
from fastapi import Request
from typing_extensions import assert_never

16
from vllm.config import VllmConfig
17
18
19
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
20
21
22
23
24
25
26
27
28
29
30
from vllm.entrypoints.openai.protocol import (
    ErrorResponse,
    IOProcessorRequest,
    IOProcessorResponse,
    PoolingChatRequest,
    PoolingCompletionRequest,
    PoolingRequest,
    PoolingResponse,
    PoolingResponseData,
    UsageInfo,
)
31
from vllm.entrypoints.openai.serving_engine import OpenAIServing
32
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
33
from vllm.entrypoints.renderer import RenderConfig
34
from vllm.entrypoints.utils import _validate_truncation_size
35
36
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
37
from vllm.plugins.io_processors import get_io_processor
38
39
40
41
42
43
44
45
from vllm.utils import merge_async_iterators

logger = init_logger(__name__)


def _get_data(
    output: PoolingOutput,
    encoding_format: Literal["float", "base64"],
) -> Union[list[float], str]:
    """Serialize the pooled tensor in ``output.data`` for the response body.

    Args:
        output: Pooling output whose ``data`` attribute is a torch tensor.
        encoding_format: ``"float"`` returns the tensor as nested Python
            lists; ``"base64"`` returns the base64 text of the tensor's
            raw float32 bytes.

    Returns:
        A (possibly nested) list of floats for ``"float"``, or a UTF-8
        base64 string for ``"base64"``.
    """
    if encoding_format == "float":
        return output.data.tolist()
    elif encoding_format == "base64":
        # Force to use float32 for base64 encoding
        # to match the OpenAI python client behavior
        pt_float32 = output.data.to(dtype=torch.float32)
        # detach()/cpu() are no-ops for a plain CPU tensor, but keep this
        # path working for the same inputs ``tolist()`` accepts above
        # (tensors that require grad or live on another device).
        pooling_bytes = pt_float32.detach().cpu().numpy().tobytes()
        return base64.b64encode(pooling_bytes).decode("utf-8")

    assert_never(encoding_format)


class OpenAIServingPooling(OpenAIServing):
    """Serving handler for pooling requests.

    Converts a pooling request (completion-style, chat-style, or an
    IO-processor plugin request) into engine prompts, submits every prompt
    to the engine client's ``encode`` call, and assembles the pooled
    outputs into a response object.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        vllm_config: VllmConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        """Store the chat-template settings and load the IO processor.

        Args:
            engine_client: Client used to fetch the tokenizer and to run
                ``encode`` for each prompt.
            vllm_config: Configuration whose ``model_config`` is forwarded
                to the base class; it is also passed to
                ``get_io_processor``.
            models: Model registry forwarded to the base class and used to
                report the model name in responses.
            request_logger: Optional request logger forwarded to the base
                class.
            chat_template: Default chat template, used when a chat request
                does not supply its own.
            chat_template_content_format: Content format passed to chat
                preprocessing.
            trust_request_chat_template: Forwarded to
                ``_validate_chat_template`` for chat requests.
            log_error_stack: Forwarded to the base class.
        """
        super().__init__(
            engine_client=engine_client,
            model_config=vllm_config.model_config,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

        # May be None; IO-processor requests are rejected in that case
        # (see create_pooling).
        io_processor_plugin = self.model_config.io_processor_plugin
        self.io_processor = get_io_processor(vllm_config, io_processor_plugin)

    async def create_pooling(
        self,
        request: PoolingRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]:
        """
        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.

        Args:
            request: The pooling request. IO-processor, chat and completion
                style requests are supported; any other type yields an
                error response.
            raw_request: The underlying HTTP request, if any. Used to derive
                the request id and the trace headers.

        Returns:
            The response built by the IO processor for IO-processor
            requests, a ``PoolingResponse`` for chat/completion requests,
            or the result of ``create_error_response`` when validation or
            preprocessing fails or the client disconnects.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        model_name = self.models.model_name()

        request_id = f"pool-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        is_io_processor_request = isinstance(request, IOProcessorRequest)
        try:
            lora_request = self._maybe_get_adapters(request)

            if self.model_config.skip_tokenizer_init:
                tokenizer = None
            else:
                tokenizer = await self.engine_client.get_tokenizer()
            renderer = self._get_renderer(tokenizer)

            if getattr(request, "dimensions", None) is not None:
                return self.create_error_response(
                    "dimensions is currently not supported"
                )

            # Called for validation only: a ValueError raised here is turned
            # into an error response below. The returned value is not needed
            # because rendering reads the setting from the request itself.
            _validate_truncation_size(
                self.max_model_len,
                getattr(request, "truncate_prompt_tokens", None),
            )

            if is_io_processor_request:
                if self.io_processor is None:
                    raise ValueError(
                        "No IOProcessor plugin installed. Please refer "
                        "to the documentation and to the "
                        "'prithvi_geospatial_mae_io_processor' "
                        "offline inference example for more details."
                    )

                validated_prompt = self.io_processor.parse_request(request)

                engine_prompts = await self.io_processor.pre_process_async(
                    prompt=validated_prompt, request_id=request_id
                )

            elif isinstance(request, PoolingChatRequest):
                error_check_ret = self._validate_chat_template(
                    request_chat_template=request.chat_template,
                    chat_template_kwargs=request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret is not None:
                    return error_check_ret

                (
                    _,
                    _,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.chat_template_content_format,
                    # In pooling requests, we are not generating tokens,
                    # so there is no need to append extra tokens to the input
                    add_generation_prompt=False,
                    continue_final_message=False,
                    add_special_tokens=request.add_special_tokens,
                )
            elif isinstance(request, PoolingCompletionRequest):
                engine_prompts = await renderer.render_prompt(
                    prompt_or_prompts=request.input,
                    config=self._build_render_config(request),
                )
            else:
                raise ValueError(f"Unsupported request of type {type(request)}")
        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
        try:
            pooling_params = request.to_pooling_params()
            # A ValueError from verify() is handled by the except below.
            pooling_params.verify("encode", self.model_config)

            # The trace headers depend only on the incoming HTTP request,
            # so resolve them once instead of once per prompt.
            trace_headers = (
                None
                if raw_request is None
                else await self._get_trace_headers(raw_request.headers)
            )

            for i, engine_prompt in enumerate(engine_prompts):
                # Each prompt gets its own engine request id.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=pooling_params,
                    lora_request=lora_request,
                )

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        if is_io_processor_request:
            # Guaranteed by the check in the preprocessing step above.
            assert self.io_processor is not None
            output = await self.io_processor.post_process_async(
                model_output=result_generator,
                request_id=request_id,
            )
            return self.io_processor.output_to_response(output)

        assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest))
        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: list[Optional[PoolingRequestOutput]]
        final_res_batch = [None] * num_prompts
        try:
            # The merged iterator yields (prompt index, result) pairs, so
            # results are stored by index.
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(list[PoolingRequestOutput], final_res_batch)

            response = self.request_output_to_pooling_response(
                final_res_batch_checked,
                request_id,
                created_time,
                model_name,
                request.encoding_format,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def request_output_to_pooling_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
        encoding_format: Literal["float", "base64"],
    ) -> PoolingResponse:
        """Build the ``PoolingResponse`` for a completed batch.

        Args:
            final_res_batch: One engine result per prompt, in prompt order.
            request_id: Identifier reported as the response ``id``.
            created_time: Timestamp reported as the response ``created``.
            model_name: Model name reported in the response.
            encoding_format: How each pooled output is serialized
                (``"float"`` or ``"base64"``).

        Returns:
            A response with one data item per result, indexed by position,
            and usage counts summed over every result's prompt token ids.
        """
        items: list[PoolingResponseData] = [
            PoolingResponseData(
                index=idx,
                data=_get_data(final_res.outputs, encoding_format),
            )
            for idx, final_res in enumerate(final_res_batch)
        ]
        num_prompt_tokens = sum(
            len(final_res.prompt_token_ids) for final_res in final_res_batch
        )

        # Only prompt tokens are counted, so total_tokens equals prompt_tokens.
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return PoolingResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )

    def _build_render_config(self, request: PoolingCompletionRequest) -> RenderConfig:
        """Build the ``RenderConfig`` used to render a completion request.

        The maximum length comes from ``self.max_model_len``; truncation and
        special-token settings are taken from the request.
        """
        return RenderConfig(
            max_length=self.max_model_len,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
        )