serving_embedding.py 25.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import base64
5
6
from collections.abc import AsyncGenerator, Mapping
from typing import Any, Final, Literal, Optional, Union, cast
7

8
import numpy as np
9
import torch
10
from fastapi import Request
11
from typing_extensions import assert_never, override
12
13

from vllm.config import ModelConfig
14
from vllm.engine.protocol import EngineClient
15
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
16
from vllm.entrypoints.logger import RequestLogger
17

18
19
# yapf conflicts with isort for this docstring
# yapf: disable
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from vllm.entrypoints.openai.protocol import (
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    EmbeddingResponseData,
    ErrorResponse,
    UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import (
    EmbeddingServeContext,
    OpenAIServing,
    ServeContext,
    TextTokensPrompt,
)

36
# yapf: enable
37
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
38
from vllm.entrypoints.renderer import RenderConfig
39
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
40
from vllm.logger import init_logger
41
42
43
44
45
46
47
from vllm.outputs import (
    EmbeddingOutput,
    EmbeddingRequestOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
48
from vllm.pooling_params import PoolingParams
49
from vllm.utils import chunk_list
50
51
52
53

logger = init_logger(__name__)


54
def _get_embedding(
    output: EmbeddingOutput,
    encoding_format: Literal["float", "base64"],
) -> Union[list[float], str]:
    """Serialize one embedding vector in the client-requested wire format.

    Args:
        output: Engine output whose ``embedding`` attribute holds the vector.
        encoding_format: ``"float"`` returns ``output.embedding`` unchanged;
            ``"base64"`` packs the vector as float32 bytes and returns the
            base64 text of those bytes.

    Returns:
        The embedding as a list of floats, or as a base64-encoded string.
    """
    if encoding_format == "base64":
        # The OpenAI python client decodes base64 embeddings as float32, so
        # pack with that dtype regardless of the model's native dtype.
        packed = np.array(output.embedding, dtype="float32").tobytes()
        return base64.b64encode(packed).decode("utf-8")

    if encoding_format == "float":
        return output.embedding

    assert_never(encoding_format)


69
class EmbeddingMixin(OpenAIServing):
    """Shared embedding-serving logic with optional chunked processing.

    When the model's pooler config enables chunked processing, prompts whose
    token count exceeds ``model_config.max_model_len`` are split into chunks
    that are encoded as separate engine requests. The per-chunk embeddings are
    then combined into one embedding per prompt via a token-count-weighted
    mean.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        pooler_config = self.model_config.pooler_config

        # Avoid repeated attribute lookups
        self.supports_chunked_processing = bool(
            pooler_config and pooler_config.enable_chunked_processing
        )
        # Optional user-configured cap on embedding input length; None means
        # "fall back to max_model_len".
        self.max_embed_len = (
            pooler_config.max_embed_len
            if pooler_config and pooler_config.max_embed_len
            else None
        )

    @override
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Resolve the LoRA adapter and render the request into engine prompts.

        Chat requests go through the chat template; completion-style
        requests are rendered by the renderer using ``_build_render_config``.

        Returns:
            ``None`` on success (``ctx.engine_prompts`` is populated), or an
            ``ErrorResponse`` if rendering raised ``ValueError``/``TypeError``.
        """
        ctx = cast(EmbeddingServeContext, ctx)
        try:
            ctx.lora_request = self._maybe_get_adapters(ctx.request)

            tokenizer = await self.engine_client.get_tokenizer()
            renderer = self._get_renderer(tokenizer)

            if isinstance(ctx.request, EmbeddingChatRequest):
                (
                    _,
                    _,
                    ctx.engine_prompts,
                ) = await self._preprocess_chat(
                    ctx.request,
                    tokenizer,
                    ctx.request.messages,
                    # A per-request template takes precedence over the
                    # server-level default carried on the context.
                    chat_template=ctx.request.chat_template or ctx.chat_template,
                    chat_template_content_format=ctx.chat_template_content_format,
                    add_generation_prompt=ctx.request.add_generation_prompt,
                    continue_final_message=False,
                    add_special_tokens=ctx.request.add_special_tokens,
                )
            else:
                ctx.engine_prompts = await renderer.render_prompt(
                    prompt_or_prompts=ctx.request.input,
                    config=self._build_render_config(ctx.request),
                )
            return None
        except (ValueError, TypeError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

    def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig:
        """Build the renderer config for a completion-style embedding request.

        With chunked processing there is no length cap at render time, since
        over-long prompts are split later. Otherwise the cap is
        ``max_embed_len`` if configured, else ``max_model_len``.
        """
        if self._should_use_chunked_processing(request):
            max_length = None
        else:
            max_length = self.max_embed_len or self.max_model_len

        return RenderConfig(
            max_length=max_length,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
        )

    @override
    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
        """Convert the collected per-prompt outputs into an EmbeddingResponse.

        Items are indexed in the order of ``ctx.final_res_batch``. Usage
        counts the prompt token IDs of every result; there are no completion
        tokens, so ``total_tokens == prompt_tokens``.
        """
        items: list[EmbeddingResponseData] = []
        num_prompt_tokens = 0

        final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)

        for idx, final_res in enumerate(final_res_batch_checked):
            embedding_res = EmbeddingRequestOutput.from_base(final_res)

            item = EmbeddingResponseData(
                index=idx,
                embedding=_get_embedding(
                    embedding_res.outputs, ctx.request.encoding_format
                ),
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return EmbeddingResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )

    def _get_max_position_embeddings(self) -> int:
        """Get the model's effective maximum sequence length for chunking."""
        return self.model_config.max_model_len

    def _should_use_chunked_processing(self, request) -> bool:
        """Check if chunked processing should be used for this request.

        True only for embedding completion/chat requests when the pooler
        config enables chunked processing.
        """
        return (
            isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest))
            and self.supports_chunked_processing
        )

    async def _process_chunked_request(
        self,
        ctx: EmbeddingServeContext,
        original_prompt: TextTokensPrompt,
        pooling_params,
        trace_headers,
        prompt_idx: int,
    ) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
        """Start one engine encode request per chunk of an over-long prompt.

        The token IDs are split with ``chunk_list`` into pieces of at most
        ``max_model_len`` tokens. Each chunk gets the request ID
        ``f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"``, which
        ``_collect_batch`` parses to regroup chunks by prompt.

        Returns:
            One result generator per chunk, in chunk order.
        """
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
        token_ids = original_prompt["prompt_token_ids"]

        # Split into chunks using max_position_embeddings
        max_pos_embeddings = self._get_max_position_embeddings()
        # Process all chunks for MEAN aggregation
        for chunk_idx, chunk_tokens in enumerate(
            chunk_list(token_ids, max_pos_embeddings)
        ):
            chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"

            chunk_engine_prompt = EngineTokensPrompt(prompt_token_ids=chunk_tokens)

            # The chunk's text is not reconstructed; only its token IDs are
            # logged.
            chunk_text = ""
            chunk_request_prompt = TextTokensPrompt(
                prompt=chunk_text, prompt_token_ids=chunk_tokens
            )

            self._log_inputs(
                chunk_request_id,
                chunk_request_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            # The raw engine generator is returned as-is; results are mapped
            # back to their prompt by request ID in _collect_batch.
            original_generator = self.engine_client.encode(
                chunk_engine_prompt,
                pooling_params,
                chunk_request_id,
                lora_request=ctx.lora_request,
                trace_headers=trace_headers,
                priority=getattr(ctx.request, "priority", 0),
            )

            generators.append(original_generator)

        return generators

    def _validate_input(
        self,
        request,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Override to support chunked processing for embedding requests.

        For embedding completion/chat requests:

        - Inputs longer than ``max_embed_len`` (or ``max_model_len`` when it
          is unset) are always rejected.
        - Inputs longer than ``max_model_len`` are then accepted only when
          chunked processing is enabled.

        Other request types use the parent implementation.

        Raises:
            ValueError: if the input exceeds either limit above.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest)):
            # Check if chunked processing is enabled for pooling models
            enable_chunked = self._should_use_chunked_processing(request)

            # Use max_position_embeddings for chunked processing decisions
            max_pos_embeddings = self._get_max_position_embeddings()

            # Determine the effective max length for validation
            if self.max_embed_len is not None:
                # Use max_embed_len for validation instead of max_model_len
                length_type = "maximum embedding input length"
                max_length_value = self.max_embed_len
            else:
                # Fall back to max_model_len validation (original behavior)
                length_type = "maximum context length"
                max_length_value = self.max_model_len

            validation_error_msg = (
                "This model's {length_type} is {max_length_value} tokens. "
                "However, you requested {token_num} tokens in the input for "
                "embedding generation. Please reduce the length of the input."
            )

            chunked_processing_error_msg = (
                "This model's {length_type} is {max_length_value} tokens. "
                "However, you requested {token_num} tokens in the input for "
                "embedding generation. Please reduce the length of the input "
                "or enable chunked processing."
            )

            # Check if input exceeds max length
            if token_num > max_length_value:
                raise ValueError(
                    validation_error_msg.format(
                        length_type=length_type,
                        max_length_value=max_length_value,
                        token_num=token_num,
                    )
                )

            # Check for chunked processing
            # when exceeding max_position_embeddings
            if token_num > max_pos_embeddings:
                if enable_chunked:
                    # Allow long inputs when chunked processing is enabled
                    logger.info(
                        "Input length %s exceeds max_position_embeddings "
                        "%s, will use chunked processing",
                        token_num,
                        max_pos_embeddings,
                    )
                else:
                    raise ValueError(
                        chunked_processing_error_msg.format(
                            length_type="maximum position embeddings length",
                            max_length_value=max_pos_embeddings,
                            token_num=token_num,
                        )
                    )

            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # For other request types, use the parent's implementation
        return super()._validate_input(request, input_ids, input_text)

    def _is_text_tokens_prompt(self, prompt) -> bool:
        """Return True if ``prompt`` is a dict with token IDs and no embeds."""
        return (
            isinstance(prompt, dict)
            and "prompt_token_ids" in prompt
            and "prompt_embeds" not in prompt
        )

    async def _create_single_prompt_generator(
        self,
        ctx: EmbeddingServeContext,
        engine_prompt: EngineTokensPrompt,
        pooling_params: PoolingParams,
        trace_headers: Optional[Mapping[str, str]],
        prompt_index: int,
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Log and submit one unchunked prompt to the engine.

        The request ID is ``f"{ctx.request_id}-{prompt_index}"``, which
        ``_collect_batch`` parses to find the prompt index.
        """
        request_id_item = f"{ctx.request_id}-{prompt_index}"

        self._log_inputs(
            request_id_item,
            engine_prompt,
            params=pooling_params,
            lora_request=ctx.lora_request,
        )

        # Return the original generator without wrapping
        return self.engine_client.encode(
            engine_prompt,
            pooling_params,
            request_id_item,
            lora_request=ctx.lora_request,
            trace_headers=trace_headers,
            priority=getattr(ctx.request, "priority", 0),
        )

    @staticmethod
    def _parse_sub_request_id(
        sub_request_id: str,
        base_request_id: str,
    ) -> Optional[tuple[int, bool]]:
        """Recover ``(prompt_idx, is_chunk)`` from a per-prompt request ID.

        This class generates two ID formats:

        - ``f"{base}-prompt-{i}-chunk-{j}"`` for chunks
          (``_process_chunked_request``);
        - ``f"{base}-{i}"`` for unchunked prompts
          (``_create_single_prompt_generator``).

        Only the suffix after ``base`` is parsed. The base ID can come from
        the client and may itself contain hyphens or the words
        "prompt"/"chunk", so parsing the whole ID could misattribute results.

        Returns:
            ``(prompt_idx, is_chunk)``, or ``None`` when ``sub_request_id``
            matches neither generated format.
        """
        prefix = f"{base_request_id}-"
        if not sub_request_id.startswith(prefix):
            return None

        fields = sub_request_id[len(prefix):].split("-")
        try:
            if len(fields) == 1:
                return int(fields[0]), False
            if (
                len(fields) == 4
                and fields[0] == "prompt"
                and fields[2] == "chunk"
            ):
                return int(fields[1]), True
        except ValueError:
            return None
        return None

    @override
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Override to support chunked processing.

        Without chunked processing this delegates to the parent. Otherwise
        each prompt longer than ``max_model_len`` tokens is fanned out into
        per-chunk requests, and every other prompt becomes a single request.
        All generators are merged into ``ctx.result_generator``.

        Returns:
            ``None`` on success, or an ``ErrorResponse`` describing the
            failure.
        """
        ctx = cast(EmbeddingServeContext, ctx)

        use_chunked = self._should_use_chunked_processing(ctx.request)

        if not use_chunked:
            return await super()._prepare_generators(ctx)

        generators: list[
            AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]
        ] = []

        try:
            trace_headers = (
                None
                if ctx.raw_request is None
                else await self._get_trace_headers(ctx.raw_request.headers)
            )

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            # Verify and set the task for pooling params
            try:
                pooling_params.verify("embed", self.model_config)
            except ValueError as e:
                return self.create_error_response(str(e))

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            max_pos_embeddings = self._get_max_position_embeddings()

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                # Only token-ID prompts can be chunked; check this one's length
                if self._is_text_tokens_prompt(engine_prompt):
                    text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
                    if len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings:
                        chunk_generators = await self._process_chunked_request(
                            ctx, text_tokens_prompt, pooling_params, trace_headers, i
                        )
                        generators.extend(chunk_generators)
                        continue

                # Normal processing for short prompts or non-token prompts
                generator = await self._create_single_prompt_generator(
                    ctx, engine_prompt, pooling_params, trace_headers, i
                )
                generators.append(generator)

            from vllm.utils import merge_async_iterators

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    @override
    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Collect and aggregate batch results with chunked-processing support.

        For chunked requests, each prompt keeps a running token-weighted sum
        of its chunk embeddings, rather than storing every chunk output. The
        final embedding is ``weighted_sum / total_weight``. Unchunked results
        are passed through. Results are emitted in prompt order into
        ``ctx.final_res_batch``. Without chunked processing this delegates to
        the parent.

        Returns:
            ``None`` on success, or an ``ErrorResponse`` describing the
            failure.
        """
        ctx = cast(EmbeddingServeContext, ctx)
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            use_chunked = self._should_use_chunked_processing(ctx.request)

            if not use_chunked:
                return await super()._collect_batch(ctx=ctx)

            if ctx.result_generator is None:
                return self.create_error_response("Result generator not available")

            # Per-prompt aggregation state for chunked prompts
            prompt_aggregators: dict[int, dict[str, Any]] = {}
            short_prompts_results: dict[int, PoolingRequestOutput] = {}

            async for result_idx, result in ctx.result_generator:
                parsed = self._parse_sub_request_id(
                    result.request_id, ctx.request_id
                )
                if parsed is None:
                    # Unrecognized ID format: fall back to the generator
                    # position and the original substring heuristic.
                    prompt_idx = result_idx
                    is_chunk = "-chunk-" in result.request_id
                else:
                    prompt_idx, is_chunk = parsed

                if is_chunk:
                    # Initialize aggregator for this prompt if needed
                    if prompt_idx not in prompt_aggregators:
                        prompt_aggregators[prompt_idx] = {
                            "weighted_sum": None,
                            "total_weight": 0,
                            "chunk_count": 0,
                            # Strip only the trailing "-chunk-{j}" so that a
                            # "-chunk-" inside the base ID is preserved.
                            "request_id": result.request_id.rsplit("-chunk-", 1)[0],
                        }

                    aggregator = prompt_aggregators[prompt_idx]

                    # MEAN pooling with online weighted averaging needs a
                    # PoolingRequestOutput for embedding processing
                    if not isinstance(result, PoolingRequestOutput):
                        return self.create_error_response(
                            f"Expected PoolingRequestOutput for "
                            f"chunked embedding, got "
                            f"{type(result).__name__}"
                        )

                    # Handle both PoolingOutput and EmbeddingOutput types
                    if hasattr(result.outputs, "data"):
                        # PoolingOutput case
                        embedding_data = result.outputs.data
                    elif hasattr(result.outputs, "embedding"):
                        # EmbeddingOutput case: a list, converted below
                        embedding_data = result.outputs.embedding
                    else:
                        return self.create_error_response(
                            f"Unsupported output type: {type(result.outputs).__name__}"
                        )

                    if not isinstance(embedding_data, torch.Tensor):
                        embedding_data = torch.tensor(
                            embedding_data, dtype=torch.float32
                        )

                    if result.prompt_token_ids is None:
                        return self.create_error_response(
                            "prompt_token_ids cannot be None for chunked processing"
                        )
                    # Weight each chunk by its token count
                    weight = len(result.prompt_token_ids)

                    weighted_embedding = embedding_data.to(dtype=torch.float32) * weight

                    if aggregator["weighted_sum"] is None:
                        # First chunk
                        aggregator["weighted_sum"] = weighted_embedding
                    else:
                        # Accumulate
                        aggregator["weighted_sum"] += weighted_embedding

                    aggregator["total_weight"] += weight
                    aggregator["chunk_count"] += 1
                else:
                    short_prompts_results[prompt_idx] = cast(
                        PoolingRequestOutput, result
                    )

            # Finalize aggregated results
            final_res_batch: list[
                Union[PoolingRequestOutput, EmbeddingRequestOutput]
            ] = []

            num_prompts = len(ctx.engine_prompts)

            for prompt_idx in range(num_prompts):
                if prompt_idx in prompt_aggregators:
                    # Finalize MEAN aggregation for this chunked prompt
                    aggregator = prompt_aggregators[prompt_idx]

                    weighted_sum = aggregator["weighted_sum"]
                    total_weight = aggregator["total_weight"]

                    if (
                        weighted_sum is not None
                        and isinstance(weighted_sum, torch.Tensor)
                        and isinstance(total_weight, (int, float))
                        and total_weight > 0
                    ):
                        final_embedding = weighted_sum / total_weight

                        # Create a PoolingRequestOutput for the aggregated result
                        pooling_output_data = PoolingOutput(data=final_embedding)

                        # Report the full (unchunked) prompt's token IDs
                        original_prompt = ctx.engine_prompts[prompt_idx]
                        if not self._is_text_tokens_prompt(original_prompt):
                            return self.create_error_response(
                                f"Chunked prompt {prompt_idx} is not a TextTokensPrompt"
                            )

                        original_token_ids = cast(TextTokensPrompt, original_prompt)[
                            "prompt_token_ids"
                        ]

                        pooling_request_output = PoolingRequestOutput(
                            request_id=aggregator["request_id"],
                            prompt_token_ids=original_token_ids,
                            outputs=pooling_output_data,
                            finished=True,
                        )

                        final_res_batch.append(pooling_request_output)
                    else:
                        return self.create_error_response(
                            f"Failed to aggregate chunks for prompt {prompt_idx}"
                        )
                elif prompt_idx in short_prompts_results:
                    final_res_batch.append(
                        cast(PoolingRequestOutput, short_prompts_results[prompt_idx])
                    )
                else:
                    return self.create_error_response(
                        f"Result not found for prompt {prompt_idx}"
                    )

            ctx.final_res_batch = cast(
                list[Union[RequestOutput, PoolingRequestOutput]], final_res_batch
            )

            return None

        except Exception as e:
            return self.create_error_response(str(e))

598
599
600
601
602
603
604
605
606
607
608
609
610

class OpenAIServingEmbedding(EmbeddingMixin):
    """Serves the OpenAI-compatible Embedding API (see ``create_embedding``)."""

    request_id_prefix = "embd"

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        super().__init__(
            engine_client=engine_client,
            model_config=model_config,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        # Server-level chat template defaults. A chat request may supply its
        # own template; whether that is allowed is checked in _preprocess
        # using trust_request_chat_template.
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
        """
        Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        model_name = self.models.model_name()
        request_id = (
            f"{self.request_id_prefix}-"
            f"{self._base_request_id(raw_request, request.request_id)}"
        )

        ctx = EmbeddingServeContext(
            request=request,
            raw_request=raw_request,
            model_name=model_name,
            request_id=request_id,
            chat_template=self.chat_template,
            chat_template_content_format=self.chat_template_content_format,
        )

        return await super().handle(ctx)  # type: ignore

    @override
    def _create_pooling_params(
        self,
        ctx: ServeContext[EmbeddingRequest],
    ) -> Union[PoolingParams, ErrorResponse]:
        """Build pooling params and verify them for the ``"embed"`` task.

        Returns:
            The verified ``PoolingParams``, or an ``ErrorResponse`` if
            construction failed or verification raised ``ValueError``.
        """
        pooling_params = super()._create_pooling_params(ctx)
        if isinstance(pooling_params, ErrorResponse):
            return pooling_params

        try:
            pooling_params.verify("embed", self.model_config)
        except ValueError as e:
            return self.create_error_response(str(e))

        return pooling_params

    @override
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Validate a request-supplied chat template, then run mixin preprocessing.

        For chat requests, ``_validate_chat_template`` is consulted with this
        server's ``trust_request_chat_template`` setting. Any error it returns
        is passed back to the caller before rendering.
        """
        if isinstance(ctx.request, EmbeddingChatRequest):
            error_check_ret = self._validate_chat_template(
                request_chat_template=ctx.request.chat_template,
                chat_template_kwargs=ctx.request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            if error_check_ret is not None:
                return error_check_ret
        return await super()._preprocess(ctx)