io_processor.py 8.13 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, cast

import torch

from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.typing import PoolingServeContext
from vllm.inputs.data import ProcessorInputs, token_inputs
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.utils.collection_utils import chunk_list


class EmbedIOProcessor(PoolingIOProcessor):
    """IO processor for embedding requests.

    When ``pooler_config.enable_chunked_processing`` is enabled, every prompt's
    token ids are split with ``chunk_list(token_ids, max_model_len)`` and each
    chunk is submitted as its own engine prompt. After inference the chunk
    embeddings are merged back into one embedding per original prompt using a
    token-count-weighted mean.
    """

    name = "embedding"

    # Raised for prompt types that carry no ``prompt_token_ids`` (the message
    # names EmbedsPrompt and EncoderDecoderInputs).
    _UNSUPPORTED_PROMPT_MSG = (
        "Long Text Embedding with Chunked Processing does "
        "not support EmbedsPrompt and EncoderDecoderInputs."
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.model_config.pooler_config is not None

        self.pooler_config = self.model_config.pooler_config
        self.enable_chunked_processing = self.pooler_config.enable_chunked_processing

    #################################################################
    # Long Text Embedding with Chunked Processing
    # PTAL: examples/pooling/embed/openai_embedding_long_text

    def pre_process_online(self, ctx: PoolingServeContext):
        """Split each prompt into chunks for long-text embedding.

        Runs the base pre-processing first. When chunked processing is
        enabled, the original prompts are stashed in ``ctx.intermediates``
        and ``ctx.engine_prompts`` is replaced by one prompt per chunk. Chunk
        ``c`` of prompt ``p`` gets the request id
        ``"{ctx.request_id}-prompt-{p}-chunk-{c}"``, which
        ``post_process_online`` parses to regroup the results.

        Raises:
            ValueError: If ``ctx.engine_prompts`` is not populated.
            NotImplementedError: If a prompt has no ``prompt_token_ids``.
        """
        super().pre_process_online(ctx)

        if not self.enable_chunked_processing:
            return None

        if ctx.engine_prompts is None:
            raise ValueError("Engine prompts not available")

        # Keep the unchunked prompts: post-processing reports their token ids.
        ctx.intermediates = ctx.engine_prompts
        request_id = ctx.request_id
        max_model_len = self.model_config.max_model_len
        chunked_engine_prompts: list[ProcessorInputs] = []
        prompt_request_ids: list[str] = []
        for prompt_idx, engine_prompt in enumerate(ctx.engine_prompts):
            token_ids = engine_prompt.get("prompt_token_ids", None)
            if token_ids is None:
                raise NotImplementedError(self._UNSUPPORTED_PROMPT_MSG)

            prompt_token_ids = cast(list[int], token_ids)

            for chunk_idx, chunk_tokens in enumerate(
                chunk_list(prompt_token_ids, max_model_len)
            ):
                chunked_engine_prompts.append(
                    token_inputs(prompt_token_ids=chunk_tokens)
                )
                prompt_request_ids.append(
                    f"{request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"
                )

        ctx.engine_prompts = chunked_engine_prompts
        ctx.prompt_request_ids = prompt_request_ids
        return None

    @staticmethod
    def _parse_chunk_request_id(request_id: str) -> tuple[str, int] | None:
        """Parse an id of the form ``"<base>-prompt-<p>-chunk-<c>"``.

        Returns ``("<base>-prompt-<p>", p)``, or ``None`` when ``request_id``
        does not end with that suffix. Parsing from the right keeps the
        result correct even if ``<base>`` itself contains ``-prompt-`` or
        ``-chunk-``.
        """
        head, sep, chunk_str = request_id.rpartition("-chunk-")
        if not sep:
            return None
        _, sep, prompt_str = head.rpartition("-prompt-")
        if not sep:
            return None
        try:
            int(chunk_str)
            prompt_idx = int(prompt_str)
        except ValueError:
            return None
        return head, prompt_idx

    def post_process_online(
        self,
        ctx: PoolingServeContext,
    ):
        """Merge per-chunk embeddings back into one result per prompt.

        Without chunked processing this defers to the base class. Otherwise,
        each chunk's embedding (cast to float32) is weighted by the chunk's
        token count and summed per prompt, then divided by the total token
        count. Results whose request id is not a chunk id are passed through
        unchanged. ``ctx.final_res_batch`` is replaced with exactly one result
        per original prompt, in original prompt order.

        Raises:
            ValueError: If the response batch or the stashed original prompts
                are missing, if a chunk result is malformed, or if a prompt
                has no result.
            NotImplementedError: If an original prompt has no
                ``prompt_token_ids``.
        """
        if ctx.final_res_batch is None:
            raise ValueError("Final response batch not available")

        if not self.enable_chunked_processing:
            return super().post_process_online(ctx)

        if ctx.intermediates is None:
            raise ValueError("Original prompts inputs not available")

        original_engine_prompts = cast(list[ProcessorInputs], ctx.intermediates)

        # Online aggregation: keep one running weighted sum per prompt rather
        # than holding every chunk embedding, to minimize memory usage.
        weighted_sums: dict[int, torch.Tensor] = {}
        total_weights: dict[int, int] = {}
        aggregated_request_ids: dict[int, str] = {}
        short_prompts_results: dict[int, PoolingRequestOutput] = {}

        for result_idx, result in enumerate(ctx.final_res_batch):
            parsed = self._parse_chunk_request_id(result.request_id)
            if parsed is None:
                # Non-chunked result: the trailing "-<int>" is the prompt index.
                try:
                    prompt_idx = int(result.request_id.rpartition("-")[2])
                except ValueError:
                    prompt_idx = result_idx  # Fallback to result_idx
                short_prompts_results[prompt_idx] = result
                continue

            base_request_id, prompt_idx = parsed

            # Ensure result is PoolingRequestOutput for embedding processing.
            if not isinstance(result, PoolingRequestOutput):
                raise ValueError(
                    f"Expected PoolingRequestOutput for "
                    f"chunked embedding, got "
                    f"{type(result).__name__}"
                )
            if result.prompt_token_ids is None:
                raise ValueError(
                    "prompt_token_ids cannot be None for chunked processing"
                )

            # MEAN aggregation: weight each chunk by its token count so longer
            # chunks contribute proportionally more.
            weight = len(result.prompt_token_ids)
            weighted_embedding = result.outputs.data.to(dtype=torch.float32) * weight

            if prompt_idx in weighted_sums:
                weighted_sums[prompt_idx] += weighted_embedding
                total_weights[prompt_idx] += weight
            else:
                # First chunk seen for this prompt.
                weighted_sums[prompt_idx] = weighted_embedding
                total_weights[prompt_idx] = weight
                aggregated_request_ids[prompt_idx] = base_request_id

        # Finalize aggregated results in original prompt order.
        final_res_batch: list[PoolingRequestOutput] = []
        for prompt_idx, original_prompt in enumerate(original_engine_prompts):
            if prompt_idx in weighted_sums:
                total_weight = total_weights[prompt_idx]
                if total_weight <= 0:
                    raise ValueError(
                        f"Failed to aggregate chunks for prompt {prompt_idx}"
                    )

                token_ids = original_prompt.get("prompt_token_ids", None)
                if token_ids is None:
                    raise NotImplementedError(self._UNSUPPORTED_PROMPT_MSG)

                # Report the full original prompt, not the last chunk.
                final_embedding = weighted_sums[prompt_idx] / total_weight
                final_res_batch.append(
                    PoolingRequestOutput(
                        request_id=aggregated_request_ids[prompt_idx],
                        prompt_token_ids=cast(list[int], token_ids),
                        outputs=PoolingOutput(data=final_embedding),
                        num_cached_tokens=0,
                        finished=True,
                    )
                )
            elif prompt_idx in short_prompts_results:
                final_res_batch.append(short_prompts_results[prompt_idx])
            else:
                raise ValueError(f"Result not found for prompt {prompt_idx}")

        ctx.final_res_batch = final_res_batch
        return None