serving.py 24.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import json
4
from collections.abc import AsyncGenerator, Mapping
5
from typing import Any, Final, TypeAlias
6

7
import torch
8
from fastapi import Request
9
from typing_extensions import assert_never
10

11
from vllm.engine.protocol import EngineClient
12
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
13
from vllm.entrypoints.logger import RequestLogger
14
15
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
16
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
17
18
19
20
21
22
23
24
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingBytesResponse,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    EmbeddingResponseData,
)
25
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
26
from vllm.logger import init_logger
27
from vllm.outputs import PoolingOutput, PoolingRequestOutput
28
from vllm.pooling_params import PoolingParams
29
30
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list
31
32
33
34
from vllm.utils.serial_utils import (
    encode_pooling_bytes,
    encode_pooling_output,
)
35
36
37
38

# Module-level logger, created through vLLM's `init_logger` with this
# module's `__name__`.
logger = init_logger(__name__)


39
# `ServeContext` parameterized with `EmbeddingRequest`; this is the `ctx`
# type accepted by the serving methods in this module.
EmbeddingServeContext: TypeAlias = ServeContext[EmbeddingRequest]
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65


class OpenAIServingEmbedding(OpenAIServing):
    request_id_prefix = "embd"

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        """Set up the embedding handler.

        The engine client, model registry, request logger and
        ``log_error_stack`` flag are forwarded to the base class. The chat
        template settings are stored for use when preprocessing chat-style
        embedding requests. Two values are derived once from the model's
        pooler config and cached on the instance:

        * ``supports_chunked_processing`` - ``True`` only when a pooler
          config exists and its ``enable_chunked_processing`` is truthy.
        * ``max_embed_len`` - the pooler config's ``max_embed_len`` when it
          is set to a truthy value, otherwise ``None``.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

        # Read the pooler settings once so request handling does not have to
        # walk model_config -> pooler_config on every call.
        cfg = self.model_config.pooler_config
        if cfg:
            self.supports_chunked_processing = bool(cfg.enable_chunked_processing)
            self.max_embed_len = cfg.max_embed_len or None
        else:
            self.supports_chunked_processing = False
            self.max_embed_len = None
78

79
    async def _preprocess(
        self,
        ctx: EmbeddingServeContext,
    ) -> ErrorResponse | None:
        """Turn the incoming request into engine prompts stored on ``ctx``.

        Sets ``ctx.lora_request`` and ``ctx.engine_prompts``.

        * Chat requests: the request's chat template is validated first; a
          validation error is returned as-is. The messages are then rendered
          through ``_preprocess_chat`` using this handler's default template
          settings.
        * Completion requests: ``request.input`` is passed to
          ``_preprocess_completion``.
        * Any other request type yields an error response.

        Returns:
            ``None`` on success, otherwise an ``ErrorResponse``. ``ValueError``
            and ``TypeError`` raised during preprocessing are logged with
            their traceback and converted into an error response.
        """
        try:
            ctx.lora_request = self._maybe_get_adapters(ctx.request)

            if isinstance(ctx.request, EmbeddingChatRequest):
                error_check_ret = self._validate_chat_template(
                    request_chat_template=ctx.request.chat_template,
                    chat_template_kwargs=ctx.request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret is not None:
                    return error_check_ret

                # The first element of the returned pair is not needed here.
                _, ctx.engine_prompts = await self._preprocess_chat(
                    ctx.request,
                    ctx.request.messages,
                    default_template=self.chat_template,
                    default_template_content_format=self.chat_template_content_format,
                    default_template_kwargs=None,
                )
            elif isinstance(ctx.request, EmbeddingCompletionRequest):
                ctx.engine_prompts = await self._preprocess_completion(
                    ctx.request,
                    prompt_input=ctx.request.input,
                    prompt_embeds=None,
                )
            else:
                # This is the embedding endpoint, so report it as such.
                return self.create_error_response("Invalid embedding request type")

            return None
        except (ValueError, TypeError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))
115

116
    def _build_response(
        self,
        ctx: EmbeddingServeContext,
    ) -> EmbeddingResponse | EmbeddingBytesResponse | ErrorResponse:
        """Serialize ``ctx.final_res_batch`` according to the request's encoding.

        * ``"float"`` / ``"base64"``: one ``EmbeddingResponseData`` per result
          (indexed by position in the batch), plus usage where prompt and
          total token counts are both the summed prompt-token lengths.
        * ``"bytes"`` / ``"bytes_only"``: the batch is packed by
          ``encode_pooling_bytes``. For ``"bytes"`` the response metadata is
          JSON-encoded into a ``metadata`` header; for ``"bytes_only"`` no
          headers are attached.
        """
        batch = ctx.final_res_batch

        request = ctx.request
        encoding_format = request.encoding_format
        embed_dtype = request.embed_dtype
        endianness = request.endianness

        if encoding_format == "float" or encoding_format == "base64":
            data: list[EmbeddingResponseData] = []
            prompt_token_total = 0

            for position, pooled in enumerate(batch):
                encoded = encode_pooling_output(
                    pooled,
                    encoding_format=encoding_format,
                    embed_dtype=embed_dtype,
                    endianness=endianness,
                )
                data.append(EmbeddingResponseData(index=position, embedding=encoded))
                prompt_token_total += len(pooled.prompt_token_ids)

            return EmbeddingResponse(
                id=ctx.request_id,
                created=ctx.created_time,
                model=ctx.model_name,
                data=data,
                usage=UsageInfo(
                    prompt_tokens=prompt_token_total,
                    total_tokens=prompt_token_total,
                ),
            )

        if encoding_format == "bytes" or encoding_format == "bytes_only":
            content, items, usage = encode_pooling_bytes(
                pooling_outputs=batch,
                embed_dtype=embed_dtype,
                endianness=endianness,
            )

            if encoding_format == "bytes_only":
                # Raw payload only: the caller asked for no metadata header.
                headers = None
            else:
                metadata = {
                    "id": ctx.request_id,
                    "created": ctx.created_time,
                    "model": ctx.model_name,
                    "data": items,
                    "usage": usage,
                }
                headers = {"metadata": json.dumps(metadata)}

            return EmbeddingBytesResponse(content=content, headers=headers)

        # Exhaustiveness guard for the encoding_format literal.
        assert_never(encoding_format)
189

190
191
192
193
194
195
    def _get_max_position_embeddings(self) -> int:
        """Get the model's effective maximum sequence length for chunking."""
        return self.model_config.max_model_len

    def _should_use_chunked_processing(self, request) -> bool:
        """Tell whether the chunked code path applies to ``request``.

        Only embedding completion/chat requests qualify, and only when the
        handler was configured with chunked processing enabled.
        """
        embedding_request_types = (EmbeddingCompletionRequest, EmbeddingChatRequest)
        if not isinstance(request, embedding_request_types):
            return False
        return self.supports_chunked_processing
200
201
202
203

    async def _process_chunked_request(
        self,
        ctx: EmbeddingServeContext,
        token_ids: list[int],
        pooling_params: PoolingParams,
        trace_headers: Mapping[str, str] | None,
        prompt_idx: int,
    ) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
        """Split one long prompt into chunks and start an encode call per chunk.

        Args:
            ctx: Serving context; supplies the request, request ID and LoRA
                request.
            token_ids: Token IDs of the full prompt.
            pooling_params: Pooling parameters shared by every chunk.
            trace_headers: Optional tracing headers forwarded to the engine.
            prompt_idx: Position of this prompt within the request. It is
                embedded in each chunk's request ID.

        Returns:
            One result generator per chunk, in chunk order.

        Each chunk's request ID has the form
        ``"<ctx.request_id>-prompt-<prompt_idx>-chunk-<chunk_idx>"``;
        ``_collect_batch`` parses that ID to regroup chunk results by prompt.
        """
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

        # Chunk size is the model's maximum sequence length.
        max_pos_embeddings = self._get_max_position_embeddings()

        # The tokenization parameters depend only on the request and the model
        # config, so build them once instead of once per chunk. The encode
        # kwargs are still produced per chunk so each engine call receives its
        # own object.
        tok_params = ctx.request.build_tok_params(self.model_config)

        for chunk_idx, chunk_tokens in enumerate(
            chunk_list(token_ids, max_pos_embeddings)
        ):
            chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"
            chunk_engine_prompt = TokensPrompt(prompt_token_ids=chunk_tokens)

            self._log_inputs(
                chunk_request_id,
                chunk_engine_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            generators.append(
                self.engine_client.encode(
                    chunk_engine_prompt,
                    pooling_params,
                    chunk_request_id,
                    lora_request=ctx.lora_request,
                    tokenization_kwargs=tok_params.get_encode_kwargs(),
                    trace_headers=trace_headers,
                    priority=ctx.request.priority,
                )
            )

        return generators

    def _validate_input(
        self,
        request: object,
        input_ids: list[int],
        input_text: str,
    ) -> TokensPrompt:
        """Length-check an embedding input, allowing long inputs when chunking is on.

        For embedding completion/chat requests:

        1. The token count must not exceed ``max_embed_len`` when that is
           configured, otherwise ``max_model_len``.
        2. If the count exceeds the chunking limit
           (``_get_max_position_embeddings()``), the input is accepted only
           when chunked processing is enabled; an info message is logged in
           that case.

        All other request types are handed to the base-class implementation.

        Raises:
            ValueError: If either limit is violated.
        """
        n_tokens = len(input_ids)

        # Embedding requests carry no max_tokens, so only the input is checked.
        if not isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest)):
            return super()._validate_input(request, input_ids, input_text)

        chunking_enabled = self._should_use_chunked_processing(request)
        chunk_limit = self._get_max_position_embeddings()

        # Pick which limit (and which wording) the hard length check uses.
        if self.max_embed_len is None:
            limit_name = "maximum context length"
            hard_limit = self.max_model_len
        else:
            limit_name = "maximum embedding input length"
            hard_limit = self.max_embed_len

        validation_error_msg = (
            "This model's {length_type} is {max_length_value} tokens. "
            "However, you requested {token_num} tokens in the input for "
            "embedding generation. Please reduce the length of the input."
        )

        chunked_processing_error_msg = (
            "This model's {length_type} is {max_length_value} tokens. "
            "However, you requested {token_num} tokens in the input for "
            "embedding generation. Please reduce the length of the input "
            "or enable chunked processing."
        )

        def render(template: str, name: str, limit: int) -> str:
            return template.format(
                length_type=name,
                max_length_value=limit,
                token_num=n_tokens,
            )

        if n_tokens > hard_limit:
            raise ValueError(render(validation_error_msg, limit_name, hard_limit))

        if n_tokens > chunk_limit:
            if not chunking_enabled:
                raise ValueError(
                    render(
                        chunked_processing_error_msg,
                        "maximum position embeddings length",
                        chunk_limit,
                    )
                )
            # Long input is fine here: it will be split into chunks later.
            logger.info(
                "Input length %s exceeds max_position_embeddings "
                "%s, will use chunked processing",
                n_tokens,
                chunk_limit,
            )

        return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    async def _create_single_prompt_generator(
        self,
        ctx: EmbeddingServeContext,
        engine_prompt: TokensPrompt | EmbedsPrompt,
        pooling_params: PoolingParams,
        trace_headers: Mapping[str, str] | None,
        prompt_index: int,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Start a regular (non-chunked) encode call for one prompt.

        The per-prompt request ID is ``"<ctx.request_id>-<prompt_index>"``.
        The inputs are logged, then the engine's result generator is returned
        unchanged.
        """
        item_request_id = f"{ctx.request_id}-{prompt_index}"
        lora_request = ctx.lora_request

        self._log_inputs(
            item_request_id,
            engine_prompt,
            params=pooling_params,
            lora_request=lora_request,
        )

        encode_kwargs = ctx.request.build_tok_params(
            self.model_config
        ).get_encode_kwargs()

        result_stream = self.engine_client.encode(
            engine_prompt,
            pooling_params,
            item_request_id,
            lora_request=lora_request,
            tokenization_kwargs=encode_kwargs,
            trace_headers=trace_headers,
            priority=ctx.request.priority,
        )
        return result_stream

    async def _prepare_generators(
        self,
        ctx: EmbeddingServeContext,
    ) -> ErrorResponse | None:
        """Create the result generators, chunking prompts that are too long.

        When chunked processing does not apply to the request, the base-class
        implementation is used. Otherwise every engine prompt is inspected:

        * prompts that carry ``prompt_token_ids`` longer than
          ``_get_max_position_embeddings()`` are split via
          ``_process_chunked_request`` (one generator per chunk);
        * all other prompts get a single generator from
          ``_create_single_prompt_generator``.

        All generators are merged into ``ctx.result_generator``.

        Returns:
            ``None`` on success, otherwise an ``ErrorResponse``.
        """
        if not self._should_use_chunked_processing(ctx.request):
            return await super()._prepare_generators(ctx)

        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

        try:
            trace_headers = (
                None
                if ctx.raw_request is None
                else await self._get_trace_headers(ctx.raw_request.headers)
            )

            # This class's _create_pooling_params already verifies the params
            # for the "embed" task and returns an ErrorResponse on failure, so
            # no second verify() call is needed here.
            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            max_pos_embeddings = self._get_max_position_embeddings()

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                # Only token prompts can be chunked; decide per prompt.
                if "prompt_token_ids" in engine_prompt:
                    prompt_token_ids = engine_prompt["prompt_token_ids"]  # type: ignore[typeddict-item]

                    if len(prompt_token_ids) > max_pos_embeddings:
                        chunk_generators = await self._process_chunked_request(
                            ctx,
                            prompt_token_ids,
                            pooling_params,
                            trace_headers,
                            i,
                        )
                        generators.extend(chunk_generators)
                        continue

                # Short prompts and non-token prompts take the normal path.
                generator = await self._create_single_prompt_generator(
                    ctx, engine_prompt, pooling_params, trace_headers, i
                )
                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            # Keep the traceback in the server log; the client only gets str(e).
            logger.exception("Error in preparing embedding generators")
            return self.create_error_response(str(e))

    async def _collect_batch(
        self,
        ctx: EmbeddingServeContext,
    ) -> ErrorResponse | None:
        """Collect engine results and merge chunk results back into prompts.

        When chunked processing does not apply to the request, the base-class
        implementation is used. Otherwise results are consumed as they
        arrive:

        * Chunk results (request ID
          ``"<ctx.request_id>-prompt-<i>-chunk-<j>"``) are folded into a
          running weighted sum for prompt ``i``, each chunk weighted by its
          number of prompt tokens. Only the running sum and total weight are
          kept per prompt, not the individual chunk embeddings.
        * Other results are stored under the prompt index taken from the
          last ``-``-separated part of the request ID, falling back to the
          merged-iterator index if that part is not an integer.

        Afterwards one ``PoolingRequestOutput`` per prompt is written to
        ``ctx.final_res_batch`` in prompt order. For chunked prompts the
        output holds the weighted mean (float32) and the original prompt's
        token IDs.

        Returns:
            ``None`` on success, otherwise an ``ErrorResponse``.
        """
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            if not self._should_use_chunked_processing(ctx.request):
                return await super()._collect_batch(ctx=ctx)

            if ctx.result_generator is None:
                return self.create_error_response("Result generator not available")

            # Per-prompt running state for chunked prompts, and finished
            # results for prompts that were not chunked.
            prompt_aggregators: dict[int, dict[str, Any]] = {}
            short_prompts_results: dict[int, PoolingRequestOutput] = {}

            async for result_idx, result in ctx.result_generator:
                chunk_info = self._parse_chunk_request_id(
                    result.request_id, ctx.request_id
                )

                if chunk_info is None:
                    # Non-chunked result: the last "-" part is the prompt index.
                    try:
                        prompt_idx = int(result.request_id.rpartition("-")[2])
                    except ValueError:
                        prompt_idx = result_idx  # Fallback to result_idx
                    short_prompts_results[prompt_idx] = result
                    continue

                prompt_idx, aggregated_request_id = chunk_info

                if not isinstance(result, PoolingRequestOutput):
                    return self.create_error_response(
                        f"Expected PoolingRequestOutput for "
                        f"chunked embedding, got "
                        f"{type(result).__name__}"
                    )

                # Accept outputs exposing either `.data` or `.embedding`.
                outputs = result.outputs
                if hasattr(outputs, "data"):
                    embedding_data = outputs.data
                elif hasattr(outputs, "embedding"):
                    embedding_data = outputs.embedding
                else:
                    return self.create_error_response(
                        f"Unsupported output type: {type(outputs).__name__}"
                    )

                if not isinstance(embedding_data, torch.Tensor):
                    embedding_data = torch.tensor(
                        embedding_data, dtype=torch.float32
                    )

                if result.prompt_token_ids is None:
                    return self.create_error_response(
                        "prompt_token_ids cannot be None for chunked processing"
                    )
                weight = len(result.prompt_token_ids)

                # The multiplication yields a new tensor, so accumulating in
                # place below never modifies the engine's output tensor.
                weighted_embedding = embedding_data.to(dtype=torch.float32) * weight

                aggregator = prompt_aggregators.setdefault(
                    prompt_idx,
                    {
                        "weighted_sum": None,
                        "total_weight": 0,
                        "request_id": aggregated_request_id,
                    },
                )
                if aggregator["weighted_sum"] is None:
                    aggregator["weighted_sum"] = weighted_embedding
                else:
                    aggregator["weighted_sum"] += weighted_embedding
                aggregator["total_weight"] += weight

            # Emit one result per prompt, in prompt order.
            final_res_batch: list[PoolingRequestOutput] = []

            for prompt_idx, original_prompt in enumerate(ctx.engine_prompts):
                if prompt_idx in prompt_aggregators:
                    aggregator = prompt_aggregators[prompt_idx]
                    weighted_sum = aggregator["weighted_sum"]
                    total_weight = aggregator["total_weight"]

                    if not (
                        isinstance(weighted_sum, torch.Tensor) and total_weight > 0
                    ):
                        return self.create_error_response(
                            f"Failed to aggregate chunks for prompt {prompt_idx}"
                        )

                    if "prompt_token_ids" not in original_prompt:
                        return self.create_error_response(
                            f"Chunked prompt {prompt_idx} does not contain "
                            "token IDs"
                        )
                    original_token_ids = original_prompt["prompt_token_ids"]  # type: ignore[typeddict-item]

                    final_res_batch.append(
                        PoolingRequestOutput(
                            request_id=aggregator["request_id"],
                            prompt_token_ids=original_token_ids,
                            outputs=PoolingOutput(data=weighted_sum / total_weight),
                            num_cached_tokens=0,
                            finished=True,
                        )
                    )
                elif prompt_idx in short_prompts_results:
                    final_res_batch.append(short_prompts_results[prompt_idx])
                else:
                    return self.create_error_response(
                        f"Result not found for prompt {prompt_idx}"
                    )

            ctx.final_res_batch = final_res_batch

            return None

        except Exception as e:
            # Keep the traceback in the server log; the client only gets str(e).
            logger.exception("Error in collecting embedding results")
            return self.create_error_response(str(e))

    @staticmethod
    def _parse_chunk_request_id(
        request_id: str,
        base_request_id: str,
    ) -> tuple[int, str] | None:
        """Recognize a chunk request ID and extract its prompt index.

        A chunk ID looks like ``"<base>-prompt-<i>-chunk-<j>"``. The known
        ``base_request_id`` prefix is removed first and the markers are
        matched from the right, so text inside the base ID that happens to
        contain ``-chunk-`` or ``-prompt-`` cannot be mistaken for them.

        Returns:
            ``(prompt_index, "<base>-prompt-<i>")`` for a chunk ID, or
            ``None`` if ``request_id`` is not a chunk ID.
        """
        suffix = request_id.removeprefix(base_request_id)

        head, chunk_sep, chunk_no = suffix.rpartition("-chunk-")
        if not chunk_sep or not (chunk_no.isascii() and chunk_no.isdigit()):
            return None

        _, prompt_sep, prompt_no = head.rpartition("-prompt-")
        if not prompt_sep or not (prompt_no.isascii() and prompt_no.isdigit()):
            return None

        # Everything before the final "-chunk-<j>" identifies the prompt.
        aggregated_id = request_id[: len(request_id) - len(chunk_sep) - len(chunk_no)]
        return int(prompt_no), aggregated_id

595
596
597
    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Request | None = None,
    ) -> EmbeddingResponse | ErrorResponse:
        """
        Handle an embedding request in the style of OpenAI's Embedding API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification that this endpoint mimics.

        A serving context is built from the request, the served model name
        and a request ID of the form ``"<request_id_prefix>-<base id>"``,
        then passed to ``self.handle``.
        """
        served_model_name = self.models.model_name()
        base_id = self._base_request_id(raw_request, request.request_id)

        ctx = EmbeddingServeContext(
            request=request,
            raw_request=raw_request,
            model_name=served_model_name,
            request_id=f"{self.request_id_prefix}-{base_id}",
        )

        return await self.handle(ctx)  # type: ignore[return-value]
620

621
622
    def _create_pooling_params(
        self,
        ctx: EmbeddingServeContext,
    ) -> PoolingParams | ErrorResponse:
        """Build pooling params via the base class and verify them for "embed".

        An ``ErrorResponse`` from the base class is passed through untouched.
        A ``ValueError`` raised by ``verify`` is converted into an error
        response.
        """
        params_or_error = super()._create_pooling_params(ctx)

        if not isinstance(params_or_error, ErrorResponse):
            try:
                params_or_error.verify("embed", self.model_config)
            except ValueError as exc:
                return self.create_error_response(str(exc))

        return params_or_error