decode_worker.py 14.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
18
import os
19
import signal
20
from typing import Optional
21

22
import connect
23
import torch
24
from components.disagg_router import PyDisaggregatedRouter
25
26
from components.encode_worker import VllmEncodeWorker
from components.prefill_worker import VllmPrefillWorker
27
from transformers import LlavaForConditionalGeneration
28
from utils.logging import check_required_workers
29
30
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
31
32
33
34
35
36
37
38
39
40
41
from utils.protocol import (
    EncodeRequest,
    EncodeResponse,
    MyRequestOutput,
    vLLMMultimodalRequest,
)
from utils.vllm import parse_vllm_args
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
42
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
43
44
from vllm.sampling_params import RequestOutputKind

45
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
46
47
48
49
50
51
52
53
54
55
56

logger = logging.getLogger(__name__)


@service(
    dynamo={
        "namespace": "dynamo",
    },
    resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
    workers=1,
)
57
class VllmDecodeWorker:
58
    # For disaggregated serving, we need to link the prefill worker to the vllm worker
59
    prefill_worker = depends(VllmPrefillWorker)
60
    # For aggregated serving, we need to link the encode worker to the vllm worker.
61
    encode_worker = depends(VllmEncodeWorker)
62
63
64
65

    def __init__(self):
        """Parse the vLLM engine arguments and prepare the worker configuration.

        Derives the served model name and the prefill-queue coordinates, forces
        the engine options that the remote-prefill path does not support yet,
        rejects the ``kv`` router and installs the SIGTERM/SIGINT handlers.

        Raises:
            NotImplementedError: If the engine arguments select the ``kv`` router.
        """
        self.client = None
        self.min_workers = 1
        self.disaggregated_router: Optional[PyDisaggregatedRouter] = None

        args = parse_vllm_args(self.__class__.__name__, "")
        self.engine_args = args
        self.do_remote_prefill = args.remote_prefill

        # Fall back to a fixed name when no served model name was configured.
        served_name = args.served_model_name
        self.model_name = "vllm" if served_name is None else served_name

        # The prefill queue is addressed by the NATS server URL plus a stream
        # that carries the model name.
        self._prefill_queue_nats_server = os.environ.get(
            "NATS_SERVER", "nats://localhost:4222"
        )
        self._prefill_queue_stream_name = self.model_name
        logger.info(
            f"Prefill queue: {self._prefill_queue_nats_server}:{self._prefill_queue_stream_name}"
        )

        if args.remote_prefill:
            # Override, with a log line each, the options remote prefill
            # cannot work with.
            if args.enable_chunked_prefill is not False:
                logger.info("Chunked prefill is not supported yet, setting to False")
                args.enable_chunked_prefill = False

            if args.preemption_mode != "swap":
                logger.info("Preemption mode is not supported yet, setting to swap")
                args.preemption_mode = "swap"

            if args.pipeline_parallel_size != 1:
                logger.info("Pipeline parallel size is not supported yet, setting to 1")
                args.pipeline_parallel_size = 1

        if args.router == "kv":
            raise NotImplementedError(
                "Multimodal requests are not supported for kv router mode"
            )

        # Route both termination signals to the same shutdown handler.
        for handled_signal in (signal.SIGTERM, signal.SIGINT):
            signal.signal(handled_signal, self.shutdown_vllm_engine)

    @async_on_start
    async def async_init(self):
        """Start the vLLM engine client and set up the mode-specific resources.

        With remote prefill enabled, the engine's NIXL metadata is published to
        a ``NixlMetadataStore``, the disaggregated router is created when
        ``conditional_disagg`` is set, and the image embedding size is read
        from the Llava vision tower.

        Otherwise an encode-worker client and a ``connect.Connector`` are
        created, and a reusable tensor for incoming image embeddings is
        allocated and registered.

        Raises:
            NotImplementedError: If the engine arguments select the ``kv`` router.
            RuntimeError: If no engine client context could be built.
        """
        # Fail fast: reject the unsupported router before the engine context is
        # entered, so a misconfiguration cannot leave a started engine behind.
        if self.engine_args.router == "kv":
            raise NotImplementedError(
                "Multimodal requests are not supported for kv router mode"
            )

        self._engine_context = build_async_engine_client_from_engine_args(
            self.engine_args
        )
        if self._engine_context is None:
            raise RuntimeError("Failed to initialize engine client")
        # The context is entered manually; nothing in this method exits it.
        self.engine_client = await self._engine_context.__aenter__()

        runtime = dynamo_context["runtime"]

        if self.do_remote_prefill:
            # Publish this engine's NIXL metadata under its engine id.
            metadata = self.engine_client.nixl_metadata
            metadata_store = NixlMetadataStore("dynamo", runtime)
            await metadata_store.put(metadata.engine_id, metadata)

            if self.engine_args.conditional_disagg:
                self.disaggregated_router = PyDisaggregatedRouter(
                    runtime,
                    self.model_name,
                    max_local_prefill_length=self.engine_args.max_local_prefill_length,
                    max_prefill_queue_size=self.engine_args.max_prefill_queue_size,
                )
                await self.disaggregated_router.async_init()
            else:
                self.disaggregated_router = None

            # The model is loaded only to read the number of position
            # embeddings of its vision tower; generate() uses that value to
            # size the run of placeholder tokens.
            model = LlavaForConditionalGeneration.from_pretrained(
                self.engine_args.model
            )
            vision_tower = model.vision_tower
            self.embedding_size = (
                vision_tower.vision_model.embeddings.position_embedding.num_embeddings
            )
        else:
            # NOTE(review): shape, dtype and device of the receive buffer are
            # hard-coded; presumably they match the encode worker's output for
            # the served model - confirm when changing models.
            EMBEDDINGS_SHAPE = (1, 577, 4096)
            EMBEDDINGS_DTYPE = torch.float16
            EMBEDDINGS_DEVICE = "cuda"

            enc_comp_ns, enc_comp_name = VllmEncodeWorker.dynamo_address()  # type: ignore
            self.encode_worker_client = (
                await runtime.namespace(enc_comp_ns)
                .component(enc_comp_name)
                .endpoint("encode")
                .client()
            )

            self._connector = connect.Connector(runtime=runtime, namespace=enc_comp_ns)
            await self._connector.initialize()

            # Create a longer-lived buffer for receiving the image embeddings.
            embeddings = torch.empty(
                EMBEDDINGS_SHAPE, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
            )
            descriptor = connect.Descriptor(embeddings)
            # Register the descriptor w/ NIXL (this is optional; if not done
            # here the connect subsystem will take care of it automatically).
            descriptor.register_memory(self._connector)
            self._embeddings_descriptor = (embeddings, descriptor)

            await check_required_workers(self.encode_worker_client, self.min_workers)
            self.disaggregated_router = None

        logger.info("Initialization complete.")
173
174
175
176
177
178
179

    def shutdown_vllm_engine(self, signum, frame):
        """Signal handler: close the engine client and stop the event loop.

        Args:
            signum: Number of the signal that was received.
            frame: Current stack frame; unused, but required by the handler
                signature that ``signal.signal`` expects.
        """
        logger.info(f"Received signal {signum}, shutting down")
        loop = asyncio.get_event_loop()
        try:
            # ``engine_client`` is only assigned once ``async_init`` has run.
            # A signal arriving earlier has nothing to close, and that must
            # not be reported as a shutdown failure.
            engine_client = getattr(self, "engine_client", None)
            if engine_client is not None:
                engine_client.close()
            logger.info("Shutdown complete.")
        except Exception as e:
            # Best effort: log with the traceback, then still stop the loop.
            logger.exception(f"Error during shutdown: {e}")
        finally:
            loop.stop()

186
187
188
189
190
191
192
193
194
195
    def get_remote_prefill_request_callback(self):
        """Return a coroutine function that enqueues a remote prefill request.

        The returned callback opens the prefill queue configured on this worker
        (NATS server and stream name) and enqueues the request it is given.
        """

        async def enqueue_remote_prefill(request: RemotePrefillRequest):
            queue_context = PrefillQueue.get_instance(
                nats_server=self._prefill_queue_nats_server,
                stream_name=self._prefill_queue_stream_name,
            )
            async with queue_context as queue:
                await queue.enqueue_prefill_request(request)

        return enqueue_remote_prefill

196
    @endpoint()
    async def generate(self, request: vLLMMultimodalRequest):
        """Stream generation results for one multimodal (image + text) request.

        Disaggregated mode (``self.do_remote_prefill``): the disaggregated
        router, when present, chooses between remote and local prefill. The
        prompt is then padded with placeholder tokens after the image token so
        that its length accounts for the image embeddings.

        Aggregated mode: an encode worker is asked to write the image
        embeddings into the pre-registered buffer, and they are handed to the
        engine as multi-modal data.

        Args:
            request: Incoming request. ``request_id``, ``image_url``,
                ``engine_prompt["prompt_token_ids"]`` and ``sampling_params``
                are read; ``prefix_hit_rate`` is read when a router is used.
                ``sampling_params.output_kind`` is overwritten with
                ``RequestOutputKind.DELTA``.

        Yields:
            One JSON-serialised ``MyRequestOutput`` per engine response.

        Raises:
            ValueError: In disaggregated mode, if the prompt contains no image
                token.
        """
        request_id = request.request_id
        image_url = request.image_url
        logger.info(
            f"Received multimodal request {{ id: {request_id}, image_url: '{image_url}' }}."
        )
        prompt_token_ids = request.engine_prompt["prompt_token_ids"]
        embeddings = None

        if self.do_remote_prefill:
            logger.debug(
                f"Disaggregated: request {{ id: {request_id}, image_url: '{image_url}' }}"
                " prefill worker will populate the decode model's key-value cache ahead of time;"
                " no direct encode worker interaction required."
            )

            if self.disaggregated_router is not None:
                async with PrefillQueue.get_instance(
                    nats_server=self._prefill_queue_nats_server,
                    stream_name=self._prefill_queue_stream_name,
                ) as prefill_queue:
                    prefill_queue_size = await prefill_queue.get_queue_size()
                disagg_router_decision = await self.disaggregated_router.prefill_remote(
                    len(prompt_token_ids),
                    request.prefix_hit_rate,
                    prefill_queue_size,
                )
            else:
                # always prefill remotely if no disaggregated router is provided
                disagg_router_decision = True

            if disagg_router_decision:
                logger.debug(
                    f"Prefilling remotely for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(prompt_token_ids)}"
                )
                remote_prefill_params = RemotePrefillParams(
                    is_remote_prefill=True,
                    remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
                    # Pass the image url as part of the RemotePrefillParams, which will be
                    # passed to the prefill worker via RemotePrefillRequest.
                    multimodal_data_source={
                        "image_url": image_url,
                    },
                )
            else:
                # NOTE(review): on this path no image embeddings reach the engine
                # (``embeddings`` stays None) although placeholder tokens are still
                # inserted below - confirm local prefill is valid in this mode.
                remote_prefill_params = None
                logger.debug(
                    f"Prefilling locally for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(prompt_token_ids)}"
                )

            # The decode worker will pre-allocate the memory based on the prompt token length
            # for the prefill worker to transfer the kv cache. As a workaround, we manually
            # insert placeholder dummy tokens based on the embedding size so that the decode
            # worker can pre-allocate the memory with the correct size.
            # The structure of the prompt will be like:
            #   "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:".
            # Since the "<image>" token is already in the prompt, only (embedding_size - 1)
            # dummy tokens need to be inserted after it.
            IMAGE_TOKEN_ID = 32000
            DUMMY_TOKEN_ID = 0
            try:
                image_token_index = prompt_token_ids.index(IMAGE_TOKEN_ID)
            except ValueError:
                # Same exception type as the bare ``list.index`` failure, with a
                # message that names the request and the missing token.
                raise ValueError(
                    f"Request {request_id}: prompt contains no image token "
                    f"(id {IMAGE_TOKEN_ID}); cannot reserve space for the image embeddings"
                ) from None
            dummy_token_index = image_token_index + 1
            prompt_ids = (
                prompt_token_ids[:dummy_token_index]
                + [DUMMY_TOKEN_ID] * (self.embedding_size - 1)
                + prompt_token_ids[dummy_token_index:]
            )
        else:
            logger.debug(
                f"Aggregated: request {{ id: {request_id}, image_url: '{image_url}' }}"
                " no prefill worker available, embeddings directly from encode worker."
            )

            # Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
            # Doing this avoids unnecessary memory de/registration with NIXL.
            # NOTE(review): this is a single buffer stored on the worker; overlapping
            # requests would write into the same tensor - confirm requests on this
            # path are serialised.
            embeddings, descriptor = self._embeddings_descriptor

            with self._connector.create_writable(descriptor) as writable:
                # Extract serialized metadata about the operation from the writable
                # operation, and use it to create a new EncodeRequest.
                encode_request = EncodeRequest(
                    request_id=request_id,
                    image_url=image_url,
                    serialized_request=writable.to_serialized(),
                )
                # Serialise once; the same payload is logged and sent.
                encode_request_json = encode_request.model_dump_json()
                logger.debug(f"Encode request: {encode_request_json}")
                encode_generator = await self.encode_worker_client.round_robin(
                    encode_request_json
                )

                async for encode_response in encode_generator:
                    encode_output = EncodeResponse.model_validate_json(
                        encode_response.data()
                    )
                    logger.info(
                        f"Received response: {{ id: {encode_output.request_id} }}"
                    )

                # Wait for the write operation to complete. This should be a no-op
                # since we've already received a response from the encode worker.
                await writable.wait_for_completion()
                # At this point, the `embeddings` tensor is filled with the image
                # embeddings from the remote encode worker.

            remote_prefill_params = None
            logger.info(
                f"Prefilling locally for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(prompt_token_ids)}"
            )
            prompt_ids = prompt_token_ids

        # rust HTTP requires Delta streaming
        request.sampling_params.output_kind = RequestOutputKind.DELTA

        # Aggregated serving: the embeddings received from the encode worker are handed
        # to the engine as multi-modal data.
        # Disaggregated serving: no embeddings are passed here; the prefill side is
        # expected to have supplied the key-value cache.
        if embeddings is not None:
            logger.debug(
                "Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
            )
            multi_modal_data = {"image": embeddings}
        else:
            logger.debug(
                "Disaggregated: no embedding data required as prefill will have provided key-value cache updates via encode worker."
            )
            multi_modal_data = None

        async for response in self.engine_client.generate(
            prompt=TokensPrompt(
                prompt_token_ids=prompt_ids,
                multi_modal_data=multi_modal_data,
            ),
            sampling_params=request.sampling_params,
            request_id=request.request_id,
            remote_prefill_params=remote_prefill_params,
        ):
            logger.debug(
                f"Yeilding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
            )
            yield MyRequestOutput(
                request_id=response.request_id,
                prompt=response.prompt,
                prompt_token_ids=response.prompt_token_ids,
                prompt_logprobs=response.prompt_logprobs,
                outputs=response.outputs,
                finished=response.finished,
            ).model_dump_json()