"vllm/vscode:/vscode.git/clone" did not exist on "b2eb84de77d759b1689bf7e567eeb2ae1a1050d7"
serving_embedding.py 25.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import AsyncGenerator, Mapping
5
from typing import Any, Final, cast
6

7
import torch
8
from fastapi import Request
9
from typing_extensions import override
10

11
from vllm.engine.protocol import EngineClient
12
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
13
from vllm.entrypoints.logger import RequestLogger
14
from vllm.entrypoints.openai.protocol import (
15
    EMBED_DTYPE_TO_TORCH_DTYPE,
16
17
18
19
20
21
22
23
24
25
26
27
28
29
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    EmbeddingResponseData,
    ErrorResponse,
    UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import (
    EmbeddingServeContext,
    OpenAIServing,
    ServeContext,
    TextTokensPrompt,
)
30
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
31
from vllm.entrypoints.openai.utils import encoding_pooling_output
32
from vllm.entrypoints.renderer import RenderConfig
33
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
34
from vllm.logger import init_logger
35
36
37
38
39
40
from vllm.outputs import (
    EmbeddingRequestOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
41
from vllm.pooling_params import PoolingParams
42
from vllm.utils import chunk_list
43
44
45
46

# Module-level logger, named after this module via vLLM's init_logger.
logger = init_logger(__name__)


47
class EmbeddingMixin(OpenAIServing):
    """Shared embedding-serving logic, including optional chunked processing.

    When the model's pooler config enables chunked processing, token prompts
    longer than the model's max length are split into chunks. Each chunk is
    sent to the engine as its own request, and the chunk embeddings are
    combined with a token-weighted mean in ``_collect_batch``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        pooler_config = self.model_config.pooler_config

        # Avoid repeated attribute lookups
        self.supports_chunked_processing = bool(
            pooler_config and pooler_config.enable_chunked_processing
        )
        self.max_embed_len = (
            pooler_config.max_embed_len
            if pooler_config and pooler_config.max_embed_len
            else None
        )

    @override
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Validate the request and render its prompts into ``ctx.engine_prompts``.

        Returns an ``ErrorResponse`` for an unsupported ``embed_dtype`` or when
        rendering raises ``ValueError``/``TypeError``; otherwise ``None``.
        """
        ctx = cast(EmbeddingServeContext, ctx)
        try:
            if ctx.request.embed_dtype not in EMBED_DTYPE_TO_TORCH_DTYPE:
                return self.create_error_response(
                    f"embed_dtype={ctx.request.embed_dtype!r} is not supported. "
                    f"Supported types: {EMBED_DTYPE_TO_TORCH_DTYPE.keys()}"
                )

            ctx.lora_request = self._maybe_get_adapters(ctx.request)

            tokenizer = await self.engine_client.get_tokenizer()
            renderer = self._get_renderer(tokenizer)

            if isinstance(ctx.request, EmbeddingChatRequest):
                # Only the third element (the engine prompts) is needed here.
                (
                    _,
                    _,
                    ctx.engine_prompts,
                ) = await self._preprocess_chat(
                    ctx.request,
                    tokenizer,
                    ctx.request.messages,
                    chat_template=ctx.request.chat_template or ctx.chat_template,
                    chat_template_content_format=ctx.chat_template_content_format,
                    add_generation_prompt=ctx.request.add_generation_prompt,
                    continue_final_message=False,
                    add_special_tokens=ctx.request.add_special_tokens,
                )
            else:
                ctx.engine_prompts = await renderer.render_prompt(
                    prompt_or_prompts=ctx.request.input,
                    config=self._build_render_config(ctx.request),
                )
            return None
        except (ValueError, TypeError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

    def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig:
        """Build the renderer config for a completion-style embedding request.

        Without chunked processing, inputs are capped at ``max_embed_len``
        (if configured), falling back to ``max_model_len``.

        With chunked processing, inputs longer than ``max_model_len`` are
        allowed because they are split into chunks later. A configured
        ``max_embed_len`` is still the hard cap, which matches the check in
        ``_validate_input``. When it is unset, no length limit is applied.
        """
        if self._should_use_chunked_processing(request):
            max_length = self.max_embed_len
        else:
            max_length = self.max_embed_len or self.max_model_len

        return RenderConfig(
            max_length=max_length,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
        )

    @override
    def _build_response(
        self,
        ctx: ServeContext,
    ) -> EmbeddingResponse | ErrorResponse:
        """Convert ``ctx.final_res_batch`` into an ``EmbeddingResponse``.

        Each result is encoded with the request's ``encoding_format`` and
        ``embed_dtype``. Usage counts the prompt tokens of all results.
        """
        items: list[EmbeddingResponseData] = []
        num_prompt_tokens = 0

        final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)

        for idx, final_res in enumerate(final_res_batch_checked):
            item = EmbeddingResponseData(
                index=idx,
                embedding=encoding_pooling_output(
                    final_res, ctx.request.encoding_format, ctx.request.embed_dtype
                ),
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return EmbeddingResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )

    def _get_max_position_embeddings(self) -> int:
        """Get the model's effective maximum sequence length for chunking."""
        return self.model_config.max_model_len

    def _should_use_chunked_processing(self, request) -> bool:
        """Check if chunked processing should be used for this request."""
        return (
            isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest))
            and self.supports_chunked_processing
        )

    async def _process_chunked_request(
        self,
        ctx: EmbeddingServeContext,
        original_prompt: TextTokensPrompt,
        pooling_params,
        trace_headers,
        prompt_idx: int,
    ) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
        """Split one token prompt into chunks and start an engine request per chunk.

        Chunks are ``_get_max_position_embeddings()`` tokens long (the last
        may be shorter). Each chunk's request id has the form
        ``{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}``;
        ``_parse_chunk_request_id`` relies on this exact format.

        Returns one engine output generator per chunk, in chunk order.
        """
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
        token_ids = original_prompt["prompt_token_ids"]

        # Split into chunks using max_position_embeddings
        max_pos_embeddings = self._get_max_position_embeddings()
        # Process all chunks for MEAN aggregation
        for chunk_idx, chunk_tokens in enumerate(
            chunk_list(token_ids, max_pos_embeddings)
        ):
            # Create a request ID for this chunk
            chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"

            # Create engine prompt for this chunk
            chunk_engine_prompt = EngineTokensPrompt(prompt_token_ids=chunk_tokens)

            # Create chunk request prompt for logging (no text is available
            # for an individual chunk, so the text field is left empty)
            chunk_text = ""
            chunk_request_prompt = TextTokensPrompt(
                prompt=chunk_text, prompt_token_ids=chunk_tokens
            )

            # Log the chunk
            self._log_inputs(
                chunk_request_id,
                chunk_request_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            # Submit this chunk to the engine; the generator is returned
            # unwrapped and merged with the others by the caller.
            original_generator = self.engine_client.encode(
                chunk_engine_prompt,
                pooling_params,
                chunk_request_id,
                lora_request=ctx.lora_request,
                trace_headers=trace_headers,
                priority=getattr(ctx.request, "priority", 0),
            )

            generators.append(original_generator)

        return generators

    def _validate_input(
        self,
        request,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Override to support chunked processing for embedding requests.

        For completion/chat embedding requests:

        - Reject inputs longer than ``max_embed_len`` (if configured) or
          ``max_model_len``.
        - Inputs longer than the max position embeddings are allowed (and
          logged) when chunked processing is enabled; otherwise they are
          rejected.

        Other request types use the parent implementation.

        Raises:
            ValueError: if the input exceeds an applicable length limit.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest)):
            # Check if chunked processing is enabled for pooling models
            enable_chunked = self._should_use_chunked_processing(request)

            # Use max_position_embeddings for chunked processing decisions
            max_pos_embeddings = self._get_max_position_embeddings()

            # Determine the effective max length for validation
            if self.max_embed_len is not None:
                # Use max_embed_len for validation instead of max_model_len
                length_type = "maximum embedding input length"
                max_length_value = self.max_embed_len
            else:
                # Fall back to max_model_len validation (original behavior)
                length_type = "maximum context length"
                max_length_value = self.max_model_len

            validation_error_msg = (
                "This model's {length_type} is {max_length_value} tokens. "
                "However, you requested {token_num} tokens in the input for "
                "embedding generation. Please reduce the length of the input."
            )

            chunked_processing_error_msg = (
                "This model's {length_type} is {max_length_value} tokens. "
                "However, you requested {token_num} tokens in the input for "
                "embedding generation. Please reduce the length of the input "
                "or enable chunked processing."
            )

            # Check if input exceeds max length
            if token_num > max_length_value:
                raise ValueError(
                    validation_error_msg.format(
                        length_type=length_type,
                        max_length_value=max_length_value,
                        token_num=token_num,
                    )
                )

            # Check for chunked processing
            # when exceeding max_position_embeddings
            if token_num > max_pos_embeddings:
                if enable_chunked:
                    # Allow long inputs when chunked processing is enabled
                    logger.info(
                        "Input length %s exceeds max_position_embeddings "
                        "%s, will use chunked processing",
                        token_num,
                        max_pos_embeddings,
                    )
                else:
                    raise ValueError(
                        chunked_processing_error_msg.format(
                            length_type="maximum position embeddings length",
                            max_length_value=max_pos_embeddings,
                            token_num=token_num,
                        )
                    )

            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # For other request types, use the parent's implementation
        return super()._validate_input(request, input_ids, input_text)

    def _is_text_tokens_prompt(self, prompt) -> bool:
        """Check if a prompt is a TextTokensPrompt (has prompt_token_ids).

        Dict prompts that also carry ``prompt_embeds`` are excluded.
        """
        return (
            isinstance(prompt, dict)
            and "prompt_token_ids" in prompt
            and "prompt_embeds" not in prompt
        )

    async def _create_single_prompt_generator(
        self,
        ctx: EmbeddingServeContext,
        engine_prompt: EngineTokensPrompt,
        pooling_params: PoolingParams,
        trace_headers: Mapping[str, str] | None,
        prompt_index: int,
    ) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]:
        """Create a generator for a single prompt using standard processing.

        The engine request id is ``{ctx.request_id}-{prompt_index}``.
        ``_collect_batch`` parses the trailing index back out of it.
        """
        request_id_item = f"{ctx.request_id}-{prompt_index}"

        self._log_inputs(
            request_id_item,
            engine_prompt,
            params=pooling_params,
            lora_request=ctx.lora_request,
        )

        # Return the original generator without wrapping
        return self.engine_client.encode(
            engine_prompt,
            pooling_params,
            request_id_item,
            lora_request=ctx.lora_request,
            trace_headers=trace_headers,
            priority=getattr(ctx.request, "priority", 0),
        )

    @override
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Override to support chunked processing.

        Without chunked processing, this delegates to the parent.

        With chunked processing:

        - Token prompts longer than the max position embeddings are fanned
          out into one engine request per chunk.
        - All other prompts get a single request.

        All generators are merged into ``ctx.result_generator``.
        """
        ctx = cast(EmbeddingServeContext, ctx)

        # Check if we should use chunked processing
        use_chunked = self._should_use_chunked_processing(ctx.request)

        # If no chunked processing needed, delegate to parent class
        if not use_chunked:
            return await super()._prepare_generators(ctx)

        # Custom logic for chunked processing
        generators: list[
            AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
        ] = []

        try:
            trace_headers = (
                None
                if ctx.raw_request is None
                else await self._get_trace_headers(ctx.raw_request.headers)
            )

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            # Verify and set the task for pooling params
            try:
                pooling_params.verify("embed", self.model_config)
            except ValueError as e:
                return self.create_error_response(str(e))

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            max_pos_embeddings = self._get_max_position_embeddings()

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                # Check if this specific prompt needs chunked processing
                if self._is_text_tokens_prompt(engine_prompt):
                    # Cast to TextTokensPrompt since we've verified
                    # prompt_token_ids
                    text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
                    if len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings:
                        # Use chunked processing for this prompt
                        chunk_generators = await self._process_chunked_request(
                            ctx, text_tokens_prompt, pooling_params, trace_headers, i
                        )
                        generators.extend(chunk_generators)
                        continue

                # Normal processing for short prompts or non-token prompts
                generator = await self._create_single_prompt_generator(
                    ctx, engine_prompt, pooling_params, trace_headers, i
                )
                generators.append(generator)

            from vllm.utils import merge_async_iterators

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    @staticmethod
    def _parse_chunk_request_id(request_id: str) -> tuple[str, int] | None:
        """Parse a chunk request id built by ``_process_chunked_request``.

        Chunk ids have the form ``{base}-prompt-{prompt_idx}-chunk-{chunk_idx}``.
        The id is parsed from the right, because ``base`` may contain
        arbitrary hyphenated text, including the words "prompt" or "chunk",
        since it embeds a caller-supplied request id.

        Returns:
            ``(f"{base}-prompt-{prompt_idx}", prompt_idx)`` for a chunk id, or
            ``None`` if ``request_id`` does not have the chunk shape.
        """
        parts = request_id.rsplit("-", 4)
        if len(parts) != 5 or parts[1] != "prompt" or parts[3] != "chunk":
            return None
        try:
            prompt_idx = int(parts[2])
            int(parts[4])  # chunk index must also be numeric
        except ValueError:
            return None
        return f"{parts[0]}-prompt-{parts[2]}", prompt_idx

    @override
    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Collect and aggregate batch results with support for chunked processing.

        For chunked requests, performs online aggregation to minimize memory
        usage. Each prompt keeps a running sum of chunk embeddings weighted
        by chunk token count, and the final embedding is that sum divided by
        the total token count (a token-weighted MEAN).

        For regular requests, results are collected normally.
        """
        ctx = cast(EmbeddingServeContext, ctx)
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            # Check if we used chunked processing
            use_chunked = self._should_use_chunked_processing(ctx.request)

            if not use_chunked:
                return await super()._collect_batch(ctx=ctx)

            if ctx.result_generator is None:
                return self.create_error_response("Result generator not available")

            # Online aggregation for chunked requests to
            # minimize memory usage
            # Track aggregation state for each prompt
            prompt_aggregators: dict[int, dict[str, Any]] = {}
            short_prompts_results: dict[int, PoolingRequestOutput] = {}

            async for result_idx, result in ctx.result_generator:
                chunk_info = self._parse_chunk_request_id(result.request_id)
                if chunk_info is not None:
                    base_request_id, prompt_idx = chunk_info

                    # Initialize aggregator for this prompt if needed
                    if prompt_idx not in prompt_aggregators:
                        prompt_aggregators[prompt_idx] = {
                            "weighted_sum": None,
                            "total_weight": 0,
                            "chunk_count": 0,
                            "request_id": base_request_id,
                        }

                    aggregator = prompt_aggregators[prompt_idx]

                    # MEAN pooling with online weighted averaging
                    # Ensure result is PoolingRequestOutput
                    # for embedding processing
                    if not isinstance(result, PoolingRequestOutput):
                        return self.create_error_response(
                            f"Expected PoolingRequestOutput for "
                            f"chunked embedding, got "
                            f"{type(result).__name__}"
                        )

                    # Handle both PoolingOutput and
                    # EmbeddingOutput types
                    if hasattr(result.outputs, "data"):
                        # PoolingOutput case
                        embedding_data = result.outputs.data
                    elif hasattr(result.outputs, "embedding"):
                        # EmbeddingOutput case -
                        # convert embedding list to tensor
                        embedding_data = result.outputs.embedding
                    else:
                        return self.create_error_response(
                            f"Unsupported output type: {type(result.outputs).__name__}"
                        )

                    if not isinstance(embedding_data, torch.Tensor):
                        embedding_data = torch.tensor(
                            embedding_data, dtype=torch.float32
                        )

                    if result.prompt_token_ids is None:
                        return self.create_error_response(
                            "prompt_token_ids cannot be None for chunked processing"
                        )
                    # Weight each chunk by its token count so the final
                    # result is a token-weighted mean over the whole prompt.
                    weight = len(result.prompt_token_ids)

                    weighted_embedding = embedding_data.to(dtype=torch.float32) * weight

                    if aggregator["weighted_sum"] is None:
                        # First chunk
                        aggregator["weighted_sum"] = weighted_embedding
                    else:
                        # Accumulate
                        aggregator["weighted_sum"] += weighted_embedding

                    aggregator["total_weight"] += weight
                    aggregator["chunk_count"] += 1
                else:
                    # Non-chunked result: the id is "{request_id}-{prompt_idx}"
                    try:
                        prompt_idx = int(result.request_id.rsplit("-", 1)[-1])
                    except ValueError:
                        prompt_idx = result_idx  # Fallback to result_idx

                    short_prompts_results[prompt_idx] = cast(
                        PoolingRequestOutput, result
                    )

            # Finalize aggregated results
            final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = []
            num_prompts = len(ctx.engine_prompts)

            for prompt_idx in range(num_prompts):
                if prompt_idx in prompt_aggregators:
                    # Finalize MEAN aggregation for this chunked prompt
                    aggregator = prompt_aggregators[prompt_idx]

                    weighted_sum = aggregator["weighted_sum"]
                    total_weight = aggregator["total_weight"]

                    if (
                        weighted_sum is not None
                        and isinstance(weighted_sum, torch.Tensor)
                        and isinstance(total_weight, (int, float))
                        and total_weight > 0
                    ):
                        # Compute final mean embedding
                        final_embedding = weighted_sum / total_weight

                        # Create a PoolingRequestOutput
                        # for the aggregated result
                        pooling_output_data = PoolingOutput(data=final_embedding)

                        # Get original prompt token IDs for this prompt
                        original_prompt = ctx.engine_prompts[prompt_idx]
                        if not self._is_text_tokens_prompt(original_prompt):
                            return self.create_error_response(
                                f"Chunked prompt {prompt_idx} is not a TextTokensPrompt"
                            )

                        original_token_ids = cast(TextTokensPrompt, original_prompt)[
                            "prompt_token_ids"
                        ]

                        pooling_request_output = PoolingRequestOutput(
                            request_id=aggregator["request_id"],
                            prompt_token_ids=original_token_ids,
                            outputs=pooling_output_data,
                            finished=True,
                        )

                        final_res_batch.append(pooling_request_output)
                    else:
                        return self.create_error_response(
                            f"Failed to aggregate chunks for prompt {prompt_idx}"
                        )
                elif prompt_idx in short_prompts_results:
                    final_res_batch.append(
                        cast(PoolingRequestOutput, short_prompts_results[prompt_idx])
                    )
                else:
                    return self.create_error_response(
                        f"Result not found for prompt {prompt_idx}"
                    )

            ctx.final_res_batch = cast(
                list[RequestOutput | PoolingRequestOutput], final_res_batch
            )

            return None

        except Exception as e:
            return self.create_error_response(str(e))

578
579
580
581
582
583
584
585
586

class OpenAIServingEmbedding(EmbeddingMixin):
    """OpenAI-compatible ``/v1/embeddings`` handler built on ``EmbeddingMixin``."""

    request_id_prefix = "embd"

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        """Store chat-template settings on top of the base serving setup.

        Args:
            engine_client: Client used to submit requests to the engine.
            models: Registry providing the served model name(s).
            request_logger: Optional logger for incoming requests.
            chat_template: Server-default chat template, used when a chat
                request does not supply its own.
            chat_template_content_format: Content format passed through to
                chat preprocessing.
            trust_request_chat_template: Forwarded to
                ``_validate_chat_template`` when checking a chat request's
                template.
            log_error_stack: Forwarded to the base class.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Request | None = None,
    ) -> EmbeddingResponse | ErrorResponse:
        """
        Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        model_name = self.models.model_name()
        request_id = (
            f"{self.request_id_prefix}-"
            f"{self._base_request_id(raw_request, request.request_id)}"
        )

        ctx = EmbeddingServeContext(
            request=request,
            raw_request=raw_request,
            model_name=model_name,
            request_id=request_id,
            chat_template=self.chat_template,
            chat_template_content_format=self.chat_template_content_format,
        )

        return await super().handle(ctx)  # type: ignore

    @override
    def _create_pooling_params(
        self,
        ctx: ServeContext[EmbeddingRequest],
    ) -> PoolingParams | ErrorResponse:
        """Create pooling params and verify them for the ``"embed"`` task.

        Returns an ``ErrorResponse`` if creation fails or verification raises
        ``ValueError``.
        """
        pooling_params = super()._create_pooling_params(ctx)
        if isinstance(pooling_params, ErrorResponse):
            return pooling_params

        try:
            pooling_params.verify("embed", self.model_config)
        except ValueError as e:
            return self.create_error_response(str(e))

        return pooling_params

    @override
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Validate a chat request's template, then run the shared preprocessing.

        Only ``EmbeddingChatRequest`` goes through the template check; its
        error response, if any, is returned before preprocessing runs.
        """
        if isinstance(ctx.request, EmbeddingChatRequest):
            error_check_ret = self._validate_chat_template(
                request_chat_template=ctx.request.chat_template,
                chat_template_kwargs=ctx.request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            if error_check_ret is not None:
                return error_check_ret
        return await super()._preprocess(ctx)