"vscode:/vscode.git/clone" did not exist on "f67ce05d0b826322f85403f1113f69ca3853aa39"
serving_embedding.py 26.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import base64
5
6
from collections.abc import AsyncGenerator, Mapping
from typing import Any, Final, Literal, Optional, Union, cast
7

8
import numpy as np
9
import torch
10
from fastapi import Request
11
from typing_extensions import assert_never, override
12
13

from vllm.config import ModelConfig
14
from vllm.engine.protocol import EngineClient
15
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
16
from vllm.entrypoints.logger import RequestLogger
17
18
# yapf conflicts with isort for this docstring
# yapf: disable
19
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
20
                                              EmbeddingCompletionRequest,
21
                                              EmbeddingRequest,
22
                                              EmbeddingResponse,
23
24
                                              EmbeddingResponseData,
                                              ErrorResponse, UsageInfo)
25
26
from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
                                                    OpenAIServing,
27
28
29
                                                    ServeContext,
                                                    TextTokensPrompt)
# yapf: enable
30
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
31
from vllm.entrypoints.renderer import RenderConfig
32
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
33
from vllm.logger import init_logger
34
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
35
                          PoolingOutput, PoolingRequestOutput, RequestOutput)
36
from vllm.pooling_params import PoolingParams
37
from vllm.utils import chunk_list
38
39
40
41

# Module-level logger for this serving module.
logger = init_logger(__name__)


def _get_embedding(
    output: EmbeddingOutput,
    encoding_format: Literal["float", "base64"],
) -> Union[list[float], str]:
    """Serialize a single embedding into the requested response format.

    Args:
        output: Embedding output; only ``output.embedding`` is read.
        encoding_format: ``"float"`` returns ``output.embedding`` unchanged;
            ``"base64"`` returns the embedding packed as float32 bytes and
            base64-encoded into a UTF-8 ``str``.

    Returns:
        The raw embedding for ``"float"``, or a base64 string for
        ``"base64"``.
    """
    if encoding_format == "float":
        return output.embedding
    elif encoding_format == "base64":
        # Force to use float32 for base64 encoding
        # to match the OpenAI python client behavior
        embedding_bytes = np.array(output.embedding, dtype="float32").tobytes()
        return base64.b64encode(embedding_bytes).decode("utf-8")

    # Static exhaustiveness check: any other value is a type error.
    assert_never(encoding_format)


class EmbeddingMixin(OpenAIServing):
    """Shared embedding request handling built on ``OpenAIServing``.

    Covers prompt preprocessing, response building and, when the pooler
    config enables it, chunked processing: prompts longer than
    ``model_config.max_model_len`` are split into chunks, each chunk is
    encoded separately, and the chunk embeddings are combined with a
    token-count-weighted mean.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``OpenAIServing`` and cache pooler
        settings used by the chunked-processing logic."""
        super().__init__(*args, **kwargs)

        pooler_config = self.model_config.pooler_config

        # Avoid repeated attribute lookups
        self.supports_chunked_processing = bool(
            pooler_config and pooler_config.enable_chunked_processing)
        self.max_embed_len = (pooler_config.max_embed_len if pooler_config
                              and pooler_config.max_embed_len else None)

    @override
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Resolve adapters and render the request into ``ctx.engine_prompts``.

        Chat requests go through ``_preprocess_chat``; completion-style
        requests go through the renderer with ``_build_render_config``.

        Returns:
            ``None`` on success, or an ``ErrorResponse`` when preprocessing
            raises ``ValueError``/``TypeError``.
        """
        ctx = cast(EmbeddingServeContext, ctx)
        try:
            ctx.lora_request = self._maybe_get_adapters(ctx.request)

            tokenizer = await self.engine_client.get_tokenizer()
            renderer = self._get_renderer(tokenizer)

            if isinstance(ctx.request, EmbeddingChatRequest):
                (
                    _,
                    _,
                    ctx.engine_prompts,
                ) = await self._preprocess_chat(
                    ctx.request,
                    tokenizer,
                    ctx.request.messages,
                    chat_template=ctx.request.chat_template
                    or ctx.chat_template,
                    chat_template_content_format=ctx.
                    chat_template_content_format,
                    add_generation_prompt=ctx.request.add_generation_prompt,
                    continue_final_message=False,
                    add_special_tokens=ctx.request.add_special_tokens,
                )
            else:
                ctx.engine_prompts = await renderer.render_prompt(
                    prompt_or_prompts=ctx.request.input,
                    config=self._build_render_config(ctx.request),
                )
            return None
        except (ValueError, TypeError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

    def _build_render_config(
            self, request: EmbeddingCompletionRequest) -> RenderConfig:
        """Build the renderer config for a completion-style request.

        ``max_length`` is ``None`` when chunked processing applies (long
        prompts are split later), otherwise ``max_embed_len`` if set, else
        ``max_model_len``.
        """
        # Set max_length based on chunked processing capability
        if self._should_use_chunked_processing(request):
            # NOTE(review): with max_length=None the renderer presumably
            # applies no length limit, so ``max_embed_len`` is not enforced
            # on this path — confirm whether it should be passed here.
            max_length = None
        else:
            max_length = self.max_embed_len or self.max_model_len

        return RenderConfig(
            max_length=max_length,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens)

    @override
    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
        """Convert ``ctx.final_res_batch`` into an ``EmbeddingResponse``.

        Items are indexed in batch order; usage counts the prompt tokens of
        every result (prompt and total tokens are equal for embeddings).
        """
        items: list[EmbeddingResponseData] = []
        num_prompt_tokens = 0

        final_res_batch_checked = cast(list[PoolingRequestOutput],
                                       ctx.final_res_batch)

        for idx, final_res in enumerate(final_res_batch_checked):
            embedding_res = EmbeddingRequestOutput.from_base(final_res)

            item = EmbeddingResponseData(
                index=idx,
                embedding=_get_embedding(embedding_res.outputs,
                                         ctx.request.encoding_format),
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return EmbeddingResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )

    def _get_max_position_embeddings(self) -> int:
        """Get the model's effective maximum sequence length for chunking."""
        return self.model_config.max_model_len

    def _should_use_chunked_processing(self, request) -> bool:
        """Check if chunked processing should be used for this request."""
        return isinstance(
            request,
            (EmbeddingCompletionRequest,
             EmbeddingChatRequest)) and self.supports_chunked_processing

    async def _process_chunked_request(
        self,
        ctx: EmbeddingServeContext,
        original_prompt: TextTokensPrompt,
        pooling_params,
        trace_headers,
        prompt_idx: int,
    ) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
        """Split one long prompt into chunks and start an encode per chunk.

        Each chunk gets the request ID
        ``f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"``;
        ``_collect_batch`` relies on this format to regroup the results.

        Returns:
            One encode generator per chunk, in chunk order.
        """
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
        token_ids = original_prompt["prompt_token_ids"]

        # Split into chunks using max_position_embeddings
        max_pos_embeddings = self._get_max_position_embeddings()
        # Process all chunks for MEAN aggregation
        for chunk_idx, chunk_tokens in enumerate(
                chunk_list(token_ids, max_pos_embeddings)):
            # Create a request ID for this chunk
            chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-"
                                f"chunk-{chunk_idx}")

            # Create engine prompt for this chunk
            chunk_engine_prompt = EngineTokensPrompt(
                prompt_token_ids=chunk_tokens)

            # Create chunk request prompt for logging
            chunk_text = ""
            chunk_request_prompt = TextTokensPrompt(
                prompt=chunk_text, prompt_token_ids=chunk_tokens)

            # Log the chunk
            self._log_inputs(chunk_request_id,
                             chunk_request_prompt,
                             params=pooling_params,
                             lora_request=ctx.lora_request)

            # Create generator for this chunk and wrap it to return indices
            original_generator = self.engine_client.encode(
                chunk_engine_prompt,
                pooling_params,
                chunk_request_id,
                lora_request=ctx.lora_request,
                trace_headers=trace_headers,
                priority=getattr(ctx.request, "priority", 0),
            )

            generators.append(original_generator)

        return generators

    def _validate_input(
        self,
        request,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Override to support chunked processing for embedding requests.

        For embedding completion/chat requests:

        * inputs longer than ``max_embed_len`` (or ``max_model_len`` when
          ``max_embed_len`` is unset) are rejected;
        * inputs longer than ``max_model_len`` are accepted only when
          chunked processing is enabled.

        Other request types are delegated to the parent implementation.

        Raises:
            ValueError: If the input exceeds one of the limits above.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request,
                      (EmbeddingCompletionRequest, EmbeddingChatRequest)):
            # Check if chunked processing is enabled for pooling models
            enable_chunked = self._should_use_chunked_processing(request)

            # Use max_position_embeddings for chunked processing decisions
            max_pos_embeddings = self._get_max_position_embeddings()

            # Determine the effective max length for validation
            if self.max_embed_len is not None:
                # Use max_embed_len for validation instead of max_model_len
                length_type = "maximum embedding input length"
                max_length_value = self.max_embed_len
            else:
                # Fall back to max_model_len validation (original behavior)
                length_type = "maximum context length"
                max_length_value = self.max_model_len

            validation_error_msg = (
                "This model's {length_type} is {max_length_value} tokens. "
                "However, you requested {token_num} tokens in the input for "
                "embedding generation. Please reduce the length of the input.")

            chunked_processing_error_msg = (
                "This model's {length_type} is {max_length_value} tokens. "
                "However, you requested {token_num} tokens in the input for "
                "embedding generation. Please reduce the length of the input "
                "or enable chunked processing.")

            # Check if input exceeds max length
            if token_num > max_length_value:
                raise ValueError(
                    validation_error_msg.format(
                        length_type=length_type,
                        max_length_value=max_length_value,
                        token_num=token_num))

            # Check for chunked processing
            # when exceeding max_position_embeddings
            if token_num > max_pos_embeddings:
                if enable_chunked:
                    # Allow long inputs when chunked processing is enabled
                    logger.info(
                        "Input length %s exceeds max_position_embeddings "
                        "%s, will use chunked processing", token_num,
                        max_pos_embeddings)
                else:
                    raise ValueError(
                        chunked_processing_error_msg.format(
                            length_type="maximum position embeddings length",
                            max_length_value=max_pos_embeddings,
                            token_num=token_num))

            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # For other request types, use the parent's implementation
        return super()._validate_input(request, input_ids, input_text)

    def _is_text_tokens_prompt(self, prompt) -> bool:
        """Check if a prompt is a TextTokensPrompt (has prompt_token_ids)."""
        return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
                and "prompt_embeds" not in prompt)

    async def _create_single_prompt_generator(
        self,
        ctx: EmbeddingServeContext,
        engine_prompt: EngineTokensPrompt,
        pooling_params: PoolingParams,
        trace_headers: Optional[Mapping[str, str]],
        prompt_index: int,
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Create a generator for a single prompt using standard processing.

        The sub-request ID is ``f"{ctx.request_id}-{prompt_index}"``;
        ``_collect_batch`` relies on this format.
        """
        request_id_item = f"{ctx.request_id}-{prompt_index}"

        self._log_inputs(request_id_item,
                         engine_prompt,
                         params=pooling_params,
                         lora_request=ctx.lora_request)

        # Return the original generator without wrapping
        return self.engine_client.encode(
            engine_prompt,
            pooling_params,
            request_id_item,
            lora_request=ctx.lora_request,
            trace_headers=trace_headers,
            priority=getattr(ctx.request, "priority", 0),
        )

    @override
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Override to support chunked processing.

        When chunked processing does not apply, this defers to the parent.
        Otherwise, token prompts longer than ``max_model_len`` are fanned
        out into per-chunk generators and all other prompts get a single
        generator. Everything is merged into ``ctx.result_generator``.
        """
        ctx = cast(EmbeddingServeContext, ctx)

        # Check if we should use chunked processing
        use_chunked = self._should_use_chunked_processing(ctx.request)

        # If no chunked processing needed, delegate to parent class
        if not use_chunked:
            return await super()._prepare_generators(ctx)

        # Custom logic for chunked processing
        generators: list[AsyncGenerator[Union[RequestOutput,
                                              PoolingRequestOutput],
                                        None]] = []

        try:
            trace_headers = (None if ctx.raw_request is None else await
                             self._get_trace_headers(ctx.raw_request.headers))

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            # Verify and set the task for pooling params
            try:
                pooling_params.verify("embed", self.model_config)
            except ValueError as e:
                return self.create_error_response(str(e))

            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            max_pos_embeddings = self._get_max_position_embeddings()

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                # Check if this specific prompt needs chunked processing
                if self._is_text_tokens_prompt(engine_prompt):
                    # Cast to TextTokensPrompt since we've verified
                    # prompt_token_ids
                    text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
                    if (len(text_tokens_prompt["prompt_token_ids"])
                            > max_pos_embeddings):
                        # Use chunked processing for this prompt
                        chunk_generators = await self._process_chunked_request(
                            ctx, text_tokens_prompt, pooling_params,
                            trace_headers, i)
                        generators.extend(chunk_generators)
                        continue

                # Normal processing for short prompts or non-token prompts
                generator = await self._create_single_prompt_generator(
                    ctx, engine_prompt, pooling_params, trace_headers, i)
                generators.append(generator)

            from vllm.utils import merge_async_iterators
            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    @staticmethod
    def _parse_prompt_index(
        result_request_id: str,
        base_request_id: str,
    ) -> tuple[Optional[int], bool]:
        """Recover ``(prompt_idx, is_chunk)`` from a sub-request ID.

        Sub-request IDs are built as ``f"{base}-{prompt_idx}"`` by
        ``_create_single_prompt_generator`` and as
        ``f"{base}-prompt-{prompt_idx}-chunk-{chunk_idx}"`` by
        ``_process_chunked_request``. Only the trailing, self-generated part
        is inspected, so a caller-supplied base ID that itself contains
        ``-prompt-`` or ``-chunk-`` cannot misroute results.

        Returns:
            The parsed prompt index (``None`` if it is not an integer) and
            whether the ID denotes a chunk.
        """
        prefix = f"{base_request_id}-"
        if result_request_id.startswith(prefix):
            suffix = result_request_id[len(prefix):]
        else:
            suffix = result_request_id
        parts = suffix.split("-")

        if len(parts) >= 4 and parts[-4] == "prompt" and parts[-2] == "chunk":
            try:
                return int(parts[-3]), True
            except ValueError:
                return None, True

        try:
            return int(parts[-1]), False
        except ValueError:
            return None, False

    @override
    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Collect and aggregate batch results
        with support for chunked processing.

        For chunked requests, performs online aggregation to
        minimize memory usage: each chunk embedding is multiplied by its
        token count and accumulated, and the sum is divided by the total
        token count once all chunks have arrived (a token-weighted mean).
        For regular requests, collects results normally.

        Returns:
            ``None`` on success (``ctx.final_res_batch`` is populated in
            prompt order), otherwise an ``ErrorResponse``.
        """
        ctx = cast(EmbeddingServeContext, ctx)
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            # Check if we used chunked processing
            use_chunked = self._should_use_chunked_processing(ctx.request)

            if not use_chunked:
                return await super()._collect_batch(ctx=ctx)

            if ctx.result_generator is None:
                return self.create_error_response(
                    "Result generator not available")

            # Online aggregation for chunked requests to
            # minimize memory usage
            # Track aggregation state for each prompt
            prompt_aggregators: dict[int, dict[str, Any]] = {}
            short_prompts_results: dict[int, PoolingRequestOutput] = {}

            async for result_idx, result in ctx.result_generator:
                parsed_idx, is_chunk = self._parse_prompt_index(
                    result.request_id, ctx.request_id)
                # Fallback: use the merged-iterator index if the request ID
                # does not carry a parseable prompt index.
                prompt_idx = (result_idx
                              if parsed_idx is None else parsed_idx)

                if is_chunk:
                    # Initialize aggregator for this prompt if needed
                    if prompt_idx not in prompt_aggregators:
                        prompt_aggregators[prompt_idx] = {
                            'weighted_sum': None,
                            'total_weight': 0,
                            'chunk_count': 0,
                            # Strip only the trailing "-chunk-N" so that a
                            # "-chunk-" inside the base ID is preserved.
                            'request_id':
                            result.request_id.rsplit("-chunk-", 1)[0],
                        }

                    aggregator = prompt_aggregators[prompt_idx]

                    # MEAN pooling with online weighted averaging
                    # Ensure result is PoolingRequestOutput
                    # for embedding processing
                    if not isinstance(result, PoolingRequestOutput):
                        return self.create_error_response(
                            f"Expected PoolingRequestOutput for "
                            f"chunked embedding, got "
                            f"{type(result).__name__}")

                    # Handle both PoolingOutput and
                    # EmbeddingOutput types
                    if hasattr(result.outputs, 'data'):
                        # PoolingOutput case
                        embedding_data = result.outputs.data
                    elif hasattr(result.outputs, 'embedding'):
                        # EmbeddingOutput case -
                        # convert embedding list to tensor
                        embedding_data = result.outputs.embedding
                    else:
                        return self.create_error_response(
                            f"Unsupported output type: "
                            f"{type(result.outputs).__name__}")

                    if not isinstance(embedding_data, torch.Tensor):
                        embedding_data = torch.tensor(embedding_data,
                                                      dtype=torch.float32)

                    if result.prompt_token_ids is None:
                        return self.create_error_response(
                            "prompt_token_ids cannot be None for "
                            "chunked processing")
                    # Weight each chunk by its token count.
                    weight = len(result.prompt_token_ids)

                    # `*` allocates a new tensor, so the in-place `+=`
                    # below never mutates the engine's output tensor.
                    weighted_embedding = embedding_data.to(
                        dtype=torch.float32) * weight

                    if aggregator['weighted_sum'] is None:
                        # First chunk
                        aggregator['weighted_sum'] = weighted_embedding
                    else:
                        # Accumulate
                        aggregator['weighted_sum'] += weighted_embedding

                    aggregator['total_weight'] += weight
                    aggregator['chunk_count'] += 1
                else:
                    short_prompts_results[prompt_idx] = cast(
                        PoolingRequestOutput, result)

            # Finalize aggregated results
            final_res_batch: list[Union[PoolingRequestOutput,
                                        EmbeddingRequestOutput]] = []
            num_prompts = len(ctx.engine_prompts)

            for prompt_idx in range(num_prompts):
                if prompt_idx in prompt_aggregators:
                    # Finalize MEAN aggregation for this chunked prompt
                    aggregator = prompt_aggregators[prompt_idx]

                    weighted_sum = aggregator['weighted_sum']
                    total_weight = aggregator['total_weight']

                    if (weighted_sum is not None
                            and isinstance(weighted_sum, torch.Tensor)
                            and isinstance(total_weight,
                                           (int, float)) and total_weight > 0):

                        # Compute final mean embedding
                        final_embedding = weighted_sum / total_weight

                        # Create a PoolingRequestOutput
                        # for the aggregated result
                        pooling_output_data = PoolingOutput(
                            data=final_embedding)

                        # Get original prompt token IDs for this prompt
                        original_prompt = ctx.engine_prompts[prompt_idx]
                        if not self._is_text_tokens_prompt(original_prompt):
                            return self.create_error_response(
                                f"Chunked prompt {prompt_idx} is not a "
                                f"TextTokensPrompt")

                        original_token_ids = cast(
                            TextTokensPrompt,
                            original_prompt)["prompt_token_ids"]

                        pooling_request_output = PoolingRequestOutput(
                            request_id=aggregator['request_id'],
                            prompt_token_ids=original_token_ids,
                            outputs=pooling_output_data,
                            finished=True)

                        final_res_batch.append(pooling_request_output)
                    else:
                        return self.create_error_response(
                            f"Failed to aggregate chunks "
                            f"for prompt {prompt_idx}")
                elif prompt_idx in short_prompts_results:
                    final_res_batch.append(
                        cast(PoolingRequestOutput,
                             short_prompts_results[prompt_idx]))
                else:
                    return self.create_error_response(
                        f"Result not found for prompt {prompt_idx}")

            ctx.final_res_batch = cast(
                list[Union[RequestOutput, PoolingRequestOutput]],
                final_res_batch)

            return None

        except Exception as e:
            return self.create_error_response(str(e))


class OpenAIServingEmbedding(EmbeddingMixin):
    """Serves OpenAI-style embedding requests (completion and chat forms).

    Adds chat-template configuration on top of ``EmbeddingMixin`` and
    validates request-supplied chat templates before preprocessing.
    """
    request_id_prefix = "embd"

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        """Store chat-template settings and initialize the serving base.

        Args:
            chat_template: Server-side default chat template, used when a
                chat request does not supply one.
            chat_template_content_format: Content format passed through to
                chat preprocessing.
            trust_request_chat_template: Forwarded to
                ``_validate_chat_template`` when a chat request is received.
            log_error_stack: Forwarded to the base class.
        """
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger,
                         log_error_stack=log_error_stack)

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
        """
        Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        model_name = self.models.model_name()
        # The request ID is "embd-<base>"; sub-request IDs derived from it
        # are parsed again during batch collection.
        request_id = (
            f"{self.request_id_prefix}-"
            f"{self._base_request_id(raw_request, request.request_id)}")

        ctx = EmbeddingServeContext(
            request=request,
            raw_request=raw_request,
            model_name=model_name,
            request_id=request_id,
            chat_template=self.chat_template,
            chat_template_content_format=self.chat_template_content_format,
        )

        return await super().handle(ctx)  # type: ignore

    @override
    def _create_pooling_params(
        self,
        ctx: ServeContext[EmbeddingRequest],
    ) -> Union[PoolingParams, ErrorResponse]:
        """Create pooling params and verify them for the ``"embed"`` task.

        Returns:
            The verified ``PoolingParams``, or an ``ErrorResponse`` if
            creation failed or verification raised ``ValueError``.
        """
        pooling_params = super()._create_pooling_params(ctx)
        if isinstance(pooling_params, ErrorResponse):
            return pooling_params

        try:
            pooling_params.verify("embed", self.model_config)
        except ValueError as e:
            return self.create_error_response(str(e))

        return pooling_params

    @override
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Validate a request-supplied chat template, then preprocess.

        Only chat requests are validated; the returned error (if any) is
        propagated before ``EmbeddingMixin._preprocess`` runs.
        """
        if isinstance(ctx.request, EmbeddingChatRequest):
            error_check_ret = self._validate_chat_template(
                request_chat_template=ctx.request.chat_template,
                chat_template_kwargs=ctx.request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            if error_check_ret is not None:
                return error_check_ret
        return await super()._preprocess(ctx)