# serving_embedding.py 25.5 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import base64
5
from collections.abc import AsyncGenerator, Mapping
6
from typing import Any, Final, Literal, cast
7

8
import numpy as np
9
import torch
10
from fastapi import Request
11
from typing_extensions import assert_never, override
12

13
from vllm.engine.protocol import EngineClient
14
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
15
from vllm.entrypoints.logger import RequestLogger
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from vllm.entrypoints.openai.protocol import (
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    EmbeddingResponseData,
    ErrorResponse,
    UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import (
    EmbeddingServeContext,
    OpenAIServing,
    ServeContext,
    TextTokensPrompt,
)
31
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
32
from vllm.entrypoints.renderer import RenderConfig
33
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
34
from vllm.logger import init_logger
35
36
37
38
39
40
41
from vllm.outputs import (
    EmbeddingOutput,
    EmbeddingRequestOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
42
from vllm.pooling_params import PoolingParams
43
from vllm.utils import chunk_list
44
45
46
47

logger = init_logger(__name__)


48
def _get_embedding(
    output: EmbeddingOutput,
    encoding_format: Literal["float", "base64"],
) -> list[float] | str:
    """Serialize one embedding in the requested response encoding.

    Args:
        output: Embedding output whose ``embedding`` values are serialized.
        encoding_format: ``"float"`` returns ``output.embedding`` unchanged;
            ``"base64"`` returns the base64 text of its float32 bytes.

    Returns:
        The embedding as a list of floats, or as a base64-encoded string.
    """
    if encoding_format == "float":
        return output.embedding
    if encoding_format == "base64":
        # Force float32 for base64 encoding to match the OpenAI python client
        # behavior. The byte order is pinned to little-endian ("<f4") so the
        # payload is the same regardless of the server host's native byte
        # order (identical to plain "float32" on little-endian hosts).
        embedding_bytes = np.asarray(output.embedding, dtype="<f4").tobytes()
        return base64.b64encode(embedding_bytes).decode("utf-8")

    assert_never(encoding_format)


63
class EmbeddingMixin(OpenAIServing):
    """Embedding-serving logic shared by the embedding request handlers.

    Adds optional chunked processing on top of ``OpenAIServing``. When the
    model's pooler config enables it, token prompts longer than
    ``max_model_len`` are processed in three steps:

    1. They are split into chunks of at most ``max_model_len`` tokens.
    2. Each chunk is encoded as its own engine request.
    3. The chunk embeddings are combined into a single embedding by a
       token-count-weighted mean.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        pooler_config = self.model_config.pooler_config

        # Avoid repeated attribute lookups on every request.
        self.supports_chunked_processing = bool(
            pooler_config and pooler_config.enable_chunked_processing
        )
        # A missing or falsy max_embed_len is normalized to None (not configured).
        self.max_embed_len = (
            pooler_config.max_embed_len
            if pooler_config and pooler_config.max_embed_len
            else None
        )

    @override
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Resolve the LoRA adapter and turn the request into engine prompts.

        Chat requests are rendered via ``_preprocess_chat``; all other requests
        go through the prompt renderer. Returns an error response when
        preprocessing raises ``ValueError``/``TypeError``, otherwise ``None``.
        """
        ctx = cast(EmbeddingServeContext, ctx)
        try:
            ctx.lora_request = self._maybe_get_adapters(ctx.request)

            tokenizer = await self.engine_client.get_tokenizer()
            renderer = self._get_renderer(tokenizer)

            if isinstance(ctx.request, EmbeddingChatRequest):
                (
                    _,
                    _,
                    ctx.engine_prompts,
                ) = await self._preprocess_chat(
                    ctx.request,
                    tokenizer,
                    ctx.request.messages,
                    chat_template=ctx.request.chat_template or ctx.chat_template,
                    chat_template_content_format=ctx.chat_template_content_format,
                    add_generation_prompt=ctx.request.add_generation_prompt,
                    continue_final_message=False,
                    add_special_tokens=ctx.request.add_special_tokens,
                )
            else:
                ctx.engine_prompts = await renderer.render_prompt(
                    prompt_or_prompts=ctx.request.input,
                    config=self._build_render_config(ctx.request),
                )
            return None
        except (ValueError, TypeError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

    def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig:
        """Build the renderer config for a completion-style embedding request.

        With chunked processing, inputs longer than the model context are
        allowed (they are split later). The configured ``max_embed_len`` still
        caps the input, as ``_validate_input`` does for chat requests. When it
        is not configured, the length is unlimited.

        Without chunked processing, the limit is ``max_embed_len`` if
        configured, else ``max_model_len``.
        """
        if self._should_use_chunked_processing(request):
            # Previously None unconditionally, which let completion requests
            # bypass a configured max_embed_len.
            max_length = self.max_embed_len
        else:
            max_length = self.max_embed_len or self.max_model_len

        return RenderConfig(
            max_length=max_length,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
        )

    @override
    def _build_response(
        self,
        ctx: ServeContext,
    ) -> EmbeddingResponse | ErrorResponse:
        """Convert the collected pooling outputs into an ``EmbeddingResponse``.

        Each entry of ``ctx.final_res_batch`` becomes one
        ``EmbeddingResponseData`` indexed by its position. Usage counts the
        prompt tokens of every entry; for aggregated chunked prompts that is
        the full original prompt.
        """
        items: list[EmbeddingResponseData] = []
        num_prompt_tokens = 0

        final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)

        for idx, final_res in enumerate(final_res_batch_checked):
            embedding_res = EmbeddingRequestOutput.from_base(final_res)
            items.append(
                EmbeddingResponseData(
                    index=idx,
                    embedding=_get_embedding(
                        embedding_res.outputs, ctx.request.encoding_format
                    ),
                )
            )
            num_prompt_tokens += len(final_res.prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return EmbeddingResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )

    def _get_max_position_embeddings(self) -> int:
        """Get the model's effective maximum sequence length for chunking."""
        return self.model_config.max_model_len

    def _should_use_chunked_processing(self, request) -> bool:
        """Check if chunked processing should be used for this request."""
        return (
            isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest))
            and self.supports_chunked_processing
        )

    def _needs_chunking(self, engine_prompt, max_pos_embeddings: int) -> bool:
        """Return True if *engine_prompt* is a token prompt longer than the limit."""
        return (
            self._is_text_tokens_prompt(engine_prompt)
            and len(engine_prompt["prompt_token_ids"]) > max_pos_embeddings
        )

    def _generator_prompt_map(
        self,
        engine_prompts,
        max_pos_embeddings: int,
    ) -> list[tuple[int, bool]]:
        """Map each generator index used by ``_prepare_generators`` to its prompt.

        ``_prepare_generators`` emits generators in prompt order: one per chunk
        for prompts that need chunking, and one otherwise. This rebuilds that
        layout so results tagged with their generator index can be routed
        without parsing request IDs, which may contain client-supplied text.

        Returns:
            A list indexed by generator index holding
            ``(prompt_idx, is_chunk)``.
        """
        mapping: list[tuple[int, bool]] = []
        for prompt_idx, engine_prompt in enumerate(engine_prompts):
            if self._needs_chunking(engine_prompt, max_pos_embeddings):
                num_tokens = len(engine_prompt["prompt_token_ids"])
                # ceil(num_tokens / max_pos_embeddings): one generator per chunk
                # of chunk_list(token_ids, max_pos_embeddings).
                num_chunks = -(-num_tokens // max_pos_embeddings)
                mapping.extend([(prompt_idx, True)] * num_chunks)
            else:
                mapping.append((prompt_idx, False))
        return mapping

    async def _process_chunked_request(
        self,
        ctx: EmbeddingServeContext,
        original_prompt: TextTokensPrompt,
        pooling_params,
        trace_headers,
        prompt_idx: int,
    ) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
        """Process a single prompt using chunked processing.

        Splits the prompt's token IDs into chunks of at most
        ``max_model_len`` tokens. It logs and submits one engine ``encode``
        request per chunk and returns the generators in chunk order.
        """
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
        token_ids = original_prompt["prompt_token_ids"]

        # Split into chunks using max_position_embeddings
        max_pos_embeddings = self._get_max_position_embeddings()
        # Process all chunks for MEAN aggregation
        for chunk_idx, chunk_tokens in enumerate(
            chunk_list(token_ids, max_pos_embeddings)
        ):
            # Unique engine request ID for this chunk.
            chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"

            chunk_engine_prompt = EngineTokensPrompt(prompt_token_ids=chunk_tokens)

            # Chunks have no source text; log them with an empty prompt string.
            chunk_request_prompt = TextTokensPrompt(
                prompt="", prompt_token_ids=chunk_tokens
            )
            self._log_inputs(
                chunk_request_id,
                chunk_request_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            generators.append(
                self.engine_client.encode(
                    chunk_engine_prompt,
                    pooling_params,
                    chunk_request_id,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=getattr(ctx.request, "priority", 0),
                )
            )

        return generators

    def _validate_input(
        self,
        request,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Override to support chunked processing for embedding requests.

        For embedding completion/chat requests:

        * The input must not exceed ``max_embed_len`` (or ``max_model_len``
          when that is unset).
        * Inputs longer than ``max_model_len`` are accepted only when chunked
          processing is enabled.

        Raises:
            ValueError: if the input violates either limit.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest)):
            # Check if chunked processing is enabled for pooling models
            enable_chunked = self._should_use_chunked_processing(request)

            # Use max_position_embeddings for chunked processing decisions
            max_pos_embeddings = self._get_max_position_embeddings()

            # Determine the effective max length for validation
            if self.max_embed_len is not None:
                # Use max_embed_len for validation instead of max_model_len
                length_type = "maximum embedding input length"
                max_length_value = self.max_embed_len
            else:
                # Fall back to max_model_len validation (original behavior)
                length_type = "maximum context length"
                max_length_value = self.max_model_len

            validation_error_msg = (
                "This model's {length_type} is {max_length_value} tokens. "
                "However, you requested {token_num} tokens in the input for "
                "embedding generation. Please reduce the length of the input."
            )

            chunked_processing_error_msg = (
                "This model's {length_type} is {max_length_value} tokens. "
                "However, you requested {token_num} tokens in the input for "
                "embedding generation. Please reduce the length of the input "
                "or enable chunked processing."
            )

            # Check if input exceeds max length
            if token_num > max_length_value:
                raise ValueError(
                    validation_error_msg.format(
                        length_type=length_type,
                        max_length_value=max_length_value,
                        token_num=token_num,
                    )
                )

            # Check for chunked processing
            # when exceeding max_position_embeddings
            if token_num > max_pos_embeddings:
                if enable_chunked:
                    # Allow long inputs when chunked processing is enabled
                    logger.info(
                        "Input length %s exceeds max_position_embeddings "
                        "%s, will use chunked processing",
                        token_num,
                        max_pos_embeddings,
                    )
                else:
                    raise ValueError(
                        chunked_processing_error_msg.format(
                            length_type="maximum position embeddings length",
                            max_length_value=max_pos_embeddings,
                            token_num=token_num,
                        )
                    )

            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # For other request types, use the parent's implementation
        return super()._validate_input(request, input_ids, input_text)

    def _is_text_tokens_prompt(self, prompt) -> bool:
        """Check if a prompt is a TextTokensPrompt (has prompt_token_ids)."""
        return (
            isinstance(prompt, dict)
            and "prompt_token_ids" in prompt
            and "prompt_embeds" not in prompt
        )

    async def _create_single_prompt_generator(
        self,
        ctx: EmbeddingServeContext,
        engine_prompt: EngineTokensPrompt,
        pooling_params: PoolingParams,
        trace_headers: Mapping[str, str] | None,
        prompt_index: int,
    ) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]:
        """Create a generator for a single prompt using standard processing."""
        request_id_item = f"{ctx.request_id}-{prompt_index}"

        self._log_inputs(
            request_id_item,
            engine_prompt,
            params=pooling_params,
            lora_request=ctx.lora_request,
        )

        # Return the original generator without wrapping
        return self.engine_client.encode(
            engine_prompt,
            pooling_params,
            request_id_item,
            lora_request=ctx.lora_request,
            trace_headers=trace_headers,
            priority=getattr(ctx.request, "priority", 0),
        )

    @override
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Override to support chunked processing.

        Delegates to the parent when chunked processing is off. Otherwise it
        creates one generator per chunk for over-long token prompts, and one
        generator for every other prompt, in prompt order. The layout must
        stay in sync with ``_generator_prompt_map``.
        """
        ctx = cast(EmbeddingServeContext, ctx)

        if not self._should_use_chunked_processing(ctx.request):
            return await super()._prepare_generators(ctx)

        generators: list[
            AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
        ] = []

        try:
            trace_headers = (
                None
                if ctx.raw_request is None
                else await self._get_trace_headers(ctx.raw_request.headers)
            )

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            # Verify and set the task for pooling params
            try:
                pooling_params.verify("embed", self.model_config)
            except ValueError as e:
                return self.create_error_response(str(e))

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            max_pos_embeddings = self._get_max_position_embeddings()

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                if self._needs_chunking(engine_prompt, max_pos_embeddings):
                    chunk_generators = await self._process_chunked_request(
                        ctx,
                        cast(TextTokensPrompt, engine_prompt),
                        pooling_params,
                        trace_headers,
                        i,
                    )
                    generators.extend(chunk_generators)
                else:
                    # Normal processing for short prompts or non-token prompts
                    generator = await self._create_single_prompt_generator(
                        ctx, engine_prompt, pooling_params, trace_headers, i
                    )
                    generators.append(generator)

            from vllm.utils import merge_async_iterators

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    @override
    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Collect and aggregate batch results with support for chunked processing.

        For chunked prompts, each chunk's embedding is accumulated online as a
        sum weighted by the chunk's token count, which keeps memory use low.
        The final embedding is ``weighted_sum / total_tokens``. Other prompts'
        results are passed through unchanged. Results are routed to prompts by
        the generator index yielded alongside each result, not by parsing
        request IDs.
        """
        ctx = cast(EmbeddingServeContext, ctx)
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            if not self._should_use_chunked_processing(ctx.request):
                return await super()._collect_batch(ctx=ctx)

            if ctx.result_generator is None:
                return self.create_error_response("Result generator not available")

            generator_to_prompt = self._generator_prompt_map(
                ctx.engine_prompts, self._get_max_position_embeddings()
            )

            # Per-prompt online aggregation state for chunked prompts.
            prompt_aggregators: dict[int, dict[str, Any]] = {}
            short_prompts_results: dict[int, PoolingRequestOutput] = {}

            async for result_idx, result in ctx.result_generator:
                if not 0 <= result_idx < len(generator_to_prompt):
                    return self.create_error_response(
                        f"Unexpected result index {result_idx}"
                    )
                prompt_idx, is_chunk = generator_to_prompt[result_idx]

                if not is_chunk:
                    short_prompts_results[prompt_idx] = cast(
                        PoolingRequestOutput, result
                    )
                    continue

                if prompt_idx not in prompt_aggregators:
                    prompt_aggregators[prompt_idx] = {
                        "weighted_sum": None,
                        "total_weight": 0,
                        "chunk_count": 0,
                        # Prompt-level ID: strip only the trailing chunk suffix.
                        "request_id": result.request_id.rsplit("-chunk-", 1)[0],
                    }
                aggregator = prompt_aggregators[prompt_idx]

                # MEAN pooling with online weighted averaging requires
                # PoolingRequestOutput results.
                if not isinstance(result, PoolingRequestOutput):
                    return self.create_error_response(
                        f"Expected PoolingRequestOutput for "
                        f"chunked embedding, got "
                        f"{type(result).__name__}"
                    )

                # Handle both PoolingOutput and EmbeddingOutput types
                if hasattr(result.outputs, "data"):
                    embedding_data = result.outputs.data
                elif hasattr(result.outputs, "embedding"):
                    embedding_data = result.outputs.embedding
                else:
                    return self.create_error_response(
                        f"Unsupported output type: {type(result.outputs).__name__}"
                    )

                if not isinstance(embedding_data, torch.Tensor):
                    embedding_data = torch.tensor(embedding_data, dtype=torch.float32)

                if result.prompt_token_ids is None:
                    return self.create_error_response(
                        "prompt_token_ids cannot be None for chunked processing"
                    )
                weight = len(result.prompt_token_ids)

                weighted_embedding = embedding_data.to(dtype=torch.float32) * weight

                if aggregator["weighted_sum"] is None:
                    aggregator["weighted_sum"] = weighted_embedding
                else:
                    aggregator["weighted_sum"] += weighted_embedding

                aggregator["total_weight"] += weight
                aggregator["chunk_count"] += 1

            # Finalize results in original prompt order.
            final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = []

            for prompt_idx, original_prompt in enumerate(ctx.engine_prompts):
                if prompt_idx in prompt_aggregators:
                    aggregator = prompt_aggregators[prompt_idx]
                    weighted_sum = aggregator["weighted_sum"]
                    total_weight = aggregator["total_weight"]

                    if not (
                        isinstance(weighted_sum, torch.Tensor)
                        and isinstance(total_weight, (int, float))
                        and total_weight > 0
                    ):
                        return self.create_error_response(
                            f"Failed to aggregate chunks for prompt {prompt_idx}"
                        )

                    if not self._is_text_tokens_prompt(original_prompt):
                        return self.create_error_response(
                            f"Chunked prompt {prompt_idx} is not a TextTokensPrompt"
                        )
                    original_token_ids = cast(TextTokensPrompt, original_prompt)[
                        "prompt_token_ids"
                    ]

                    final_res_batch.append(
                        PoolingRequestOutput(
                            request_id=aggregator["request_id"],
                            prompt_token_ids=original_token_ids,
                            # Token-weighted mean over all chunks.
                            outputs=PoolingOutput(data=weighted_sum / total_weight),
                            finished=True,
                        )
                    )
                elif prompt_idx in short_prompts_results:
                    final_res_batch.append(short_prompts_results[prompt_idx])
                else:
                    return self.create_error_response(
                        f"Result not found for prompt {prompt_idx}"
                    )

            ctx.final_res_batch = cast(
                list[RequestOutput | PoolingRequestOutput], final_res_batch
            )

            return None

        except Exception as e:
            return self.create_error_response(str(e))

590
591
592
593
594
595
596
597
598

class OpenAIServingEmbedding(EmbeddingMixin):
    """Handler for the OpenAI-compatible embeddings API (see ``create_embedding``)."""

    request_id_prefix = "embd"

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        # Server-side chat template settings, used for EmbeddingChatRequest.
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        # Whether a request may supply its own chat template (checked in _preprocess).
        self.trust_request_chat_template = trust_request_chat_template

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Request | None = None,
    ) -> EmbeddingResponse | ErrorResponse:
        """
        Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        model_name = self.models.model_name()
        request_id = (
            f"{self.request_id_prefix}-"
            f"{self._base_request_id(raw_request, request.request_id)}"
        )

        ctx = EmbeddingServeContext(
            request=request,
            raw_request=raw_request,
            model_name=model_name,
            request_id=request_id,
            chat_template=self.chat_template,
            chat_template_content_format=self.chat_template_content_format,
        )

        return await super().handle(ctx)  # type: ignore

    @override
    def _create_pooling_params(
        self,
        ctx: ServeContext[EmbeddingRequest],
    ) -> PoolingParams | ErrorResponse:
        """Create pooling params and verify them for the ``"embed"`` task.

        Returns an error response if the parent fails or verification raises
        ``ValueError``.
        """
        pooling_params = super()._create_pooling_params(ctx)
        if isinstance(pooling_params, ErrorResponse):
            return pooling_params

        try:
            pooling_params.verify("embed", self.model_config)
        except ValueError as e:
            return self.create_error_response(str(e))

        return pooling_params

    @override
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Validate request-supplied chat templates, then run shared preprocessing."""
        if isinstance(ctx.request, EmbeddingChatRequest):
            error_check_ret = self._validate_chat_template(
                request_chat_template=ctx.request.chat_template,
                chat_template_kwargs=ctx.request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            if error_check_ret is not None:
                return error_check_ret
        return await super()._preprocess(ctx)