# decode_worker.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
import os
import signal
from typing import Optional

import connect
import torch
from components.disagg_router import PyDisaggregatedRouter
from components.encode_worker import VllmEncodeWorker
from components.prefill_worker import VllmPrefillWorker
from transformers import LlavaForConditionalGeneration
from utils.logging import check_required_workers
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import (
    EncodeRequest,
    EncodeResponse,
    MyRequestOutput,
    vLLMMultimodalRequest,
)
from utils.vllm import parse_vllm_args
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from vllm.sampling_params import RequestOutputKind

from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service

# Module-level logger used by VllmDecodeWorker below.
logger = logging.getLogger(name=__name__)


@service(
    dynamo={
        "namespace": "dynamo",
    },
    resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
    workers=1,
)
class VllmDecodeWorker:
    """Decode worker for multimodal (LLaVA-style) requests served by vLLM.

    Two serving configurations are supported, selected by
    ``engine_args.remote_prefill``:

    * Disaggregated: a remote prefill worker populates this engine's KV cache.
      Requests are handed over through a NATS-backed ``PrefillQueue`` and this
      worker only decodes.
    * Aggregated: image embeddings are requested from an encode worker, which
      writes them into a buffer registered with NIXL by this worker. Prefill
      and decode then both run on the local engine.
    """

    # For disaggregated serving, we need to link the prefill worker to the vllm worker
    prefill_worker = depends(VllmPrefillWorker)
    # For aggregated serving, we need to link the encode worker to the vllm worker.
    encode_worker = depends(VllmEncodeWorker)

    # Default id of the "<image>" placeholder token. In disaggregated mode,
    # async_init() replaces it with the model config's value when available.
    _image_token_id = 32000

    def __init__(self):
        self.client = None
        self.min_workers = 1
        self.disaggregated_router: Optional[PyDisaggregatedRouter] = None
        # Populated by async_init(). Pre-set so the signal handler can tell
        # whether the engine was ever started.
        self.engine_client = None

        class_name = self.__class__.__name__
        self.engine_args = parse_vllm_args(class_name, "")
        self.do_remote_prefill = self.engine_args.remote_prefill
        self.model_name = (
            self.engine_args.served_model_name
            if self.engine_args.served_model_name is not None
            else "vllm"
        )

        # The prefill queue stream is named after the served model.
        self._prefill_queue_nats_server = os.getenv(
            "NATS_SERVER", "nats://localhost:4222"
        )
        self._prefill_queue_stream_name = self.model_name
        logger.info(
            f"Prefill queue: {self._prefill_queue_nats_server}:{self._prefill_queue_stream_name}"
        )

        if self.engine_args.remote_prefill:
            # Remote prefill supports only a restricted engine configuration.
            # Coerce unsupported options rather than failing.
            if self.engine_args.enable_chunked_prefill is not False:
                logger.info("Chunked prefill is not supported yet, setting to False")
                self.engine_args.enable_chunked_prefill = False

            if self.engine_args.preemption_mode != "swap":
                logger.info("Preemption mode is not supported yet, setting to swap")
                self.engine_args.preemption_mode = "swap"

            if self.engine_args.pipeline_parallel_size != 1:
                logger.info("Pipeline parallel size is not supported yet, setting to 1")
                self.engine_args.pipeline_parallel_size = 1

        if self.engine_args.router == "kv":
            raise NotImplementedError(
                "Multimodal requests are not supported for kv router mode"
            )

        signal.signal(signal.SIGTERM, self.shutdown_vllm_engine)
        signal.signal(signal.SIGINT, self.shutdown_vllm_engine)

    @async_on_start
    async def async_init(self):
        """Start the vLLM engine and wire up the mode-specific dependencies.

        Raises:
            RuntimeError: if no engine client context could be built.
        """
        self._engine_context = build_async_engine_client_from_engine_args(
            self.engine_args
        )
        if self._engine_context is None:
            raise RuntimeError("Failed to initialize engine client")
        self.engine_client = await self._engine_context.__aenter__()

        # The "kv" router mode is already rejected in __init__, so it is not
        # re-checked here.
        runtime = dynamo_context["runtime"]

        if self.do_remote_prefill:
            await self._init_disaggregated(runtime)
        else:
            await self._init_aggregated(runtime)

        configuration = "Disaggregated" if self.do_remote_prefill else "Aggregated"
        logger.info("Initialization complete { configuration: %s }.", configuration)

    async def _init_disaggregated(self, runtime):
        """Set up remote prefill: NIXL metadata, optional router, placeholder sizing."""
        # Publish this engine's NIXL metadata under its engine id in the
        # "dynamo" metadata store.
        metadata = self.engine_client.nixl_metadata
        metadata_store = NixlMetadataStore("dynamo", runtime)
        await metadata_store.put(metadata.engine_id, metadata)

        if self.engine_args.conditional_disagg:
            self.disaggregated_router = PyDisaggregatedRouter(
                runtime,
                self.model_name,
                max_local_prefill_length=self.engine_args.max_local_prefill_length,
                max_prefill_queue_size=self.engine_args.max_prefill_queue_size,
            )
            await self.disaggregated_router.async_init()
        else:
            self.disaggregated_router = None

        # Only the vision tower's position-embedding count and the image token
        # id are needed; remote_prefill() uses them to build the placeholder.
        model = LlavaForConditionalGeneration.from_pretrained(self.engine_args.model)
        vision_tower = model.vision_tower
        self.embedding_size = (
            vision_tower.vision_model.embeddings.position_embedding.num_embeddings
        )
        self._image_token_id = getattr(
            model.config, "image_token_index", self._image_token_id
        )

    async def _init_aggregated(self, runtime):
        """Connect to the encode worker and allocate the shared embeddings buffer."""
        EMBEDDINGS_SHAPE = (1, 577, 4096)
        EMBEDDINGS_DTYPE = torch.float16
        EMBEDDINGS_DEVICE = "cuda"

        enc_comp_ns, enc_comp_name = VllmEncodeWorker.dynamo_address()  # type: ignore
        self.encode_worker_client = (
            await runtime.namespace(enc_comp_ns)
            .component(enc_comp_name)
            .endpoint("encode")
            .client()
        )

        self._connector = connect.Connector(runtime=runtime, namespace=enc_comp_ns)
        await self._connector.initialize()

        # Create a longer-lived buffer for receiving the image embeddings.
        embeddings = torch.empty(
            EMBEDDINGS_SHAPE, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
        )
        descriptor = connect.Descriptor(embeddings)
        # Register the descriptor w/ NIXL (this is optional, if not done here
        # the connect subsystem will take care of this automatically).
        descriptor.register_memory(self._connector)
        self._embeddings_descriptor = (embeddings, descriptor)
        # The buffer above is shared by all requests. local_prefill() holds this
        # lock for as long as one request is using it.
        self._embeddings_lock = asyncio.Lock()

        await check_required_workers(self.encode_worker_client, self.min_workers)
        self.disaggregated_router = None

    def shutdown_vllm_engine(self, signum, frame):
        """Signal handler: close the engine client and stop the event loop.

        Registered for SIGTERM and SIGINT in ``__init__``. Errors raised while
        closing are logged, and the loop is stopped regardless.
        """
        logger.info(f"Received signal {signum}, shutting down")
        loop = asyncio.get_event_loop()
        try:
            engine_client = getattr(self, "engine_client", None)
            # The signal may arrive before async_init() has started the engine.
            if engine_client is not None:
                engine_client.close()
            logger.info("Shutdown complete.")
        except Exception as e:
            logger.error(f"Error during shutdown: {e}")
        finally:
            loop.stop()

    def get_remote_prefill_request_callback(self):
        """Return a coroutine that enqueues a RemotePrefillRequest on the prefill queue."""

        async def callback(request: RemotePrefillRequest):
            async with PrefillQueue.get_instance(
                nats_server=self._prefill_queue_nats_server,
                stream_name=self._prefill_queue_stream_name,
            ) as prefill_queue:
                await prefill_queue.enqueue_prefill_request(request)

        return callback

    @endpoint()
    async def generate(self, request: vLLMMultimodalRequest):
        """Stream the engine's output for a multimodal request.

        Prefill runs either remotely (disaggregated) or locally with embeddings
        from the encode worker (aggregated). Each engine output is yielded as a
        ``MyRequestOutput`` JSON string.
        """
        request_id = request.request_id
        logger.info(f"Received multimodal request {{ id: {request_id} }}.")
        if self.do_remote_prefill:
            (
                prompt_ids,
                multi_modal_data,
                remote_prefill_params,
            ) = await self.remote_prefill(request)
        else:
            (
                prompt_ids,
                multi_modal_data,
                remote_prefill_params,
            ) = await self.local_prefill(request)

        # rust HTTP requires Delta streaming
        request.sampling_params.output_kind = RequestOutputKind.DELTA

        async for response in self.engine_client.generate(
            prompt=TokensPrompt(
                prompt_token_ids=prompt_ids,
                multi_modal_data=multi_modal_data,
            ),
            sampling_params=request.sampling_params,
            request_id=request.request_id,
            remote_prefill_params=remote_prefill_params,
        ):
            logger.debug(
                f"Yeilding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
            )
            yield MyRequestOutput(
                request_id=response.request_id,
                prompt=response.prompt,
                prompt_token_ids=response.prompt_token_ids,
                prompt_logprobs=response.prompt_logprobs,
                outputs=response.outputs,
                finished=response.finished,
            ).model_dump_json()

    async def local_prefill(self, request: vLLMMultimodalRequest) -> tuple:
        """
        Handles local prefill in aggregated serving mode.

        Interacts with the encode worker to obtain image embeddings and returns
        the original prompt tokens with multi-modal data for local processing.

        Every request on this worker shares one pre-registered embeddings
        buffer. The encode round-trip therefore runs under ``_embeddings_lock``,
        and the result is copied out before the lock is released. Without this,
        a concurrent request could overwrite the buffer while the engine still
        references it.

        Args:
            request: The multimodal request containing image URL and prompt data

        Returns:
            Tuple of (prompt_ids, multi_modal_data, remote_prefill_params)
        """
        logger.debug(
            f"Aggregated: request {{ id: {request.request_id} }}"
            " no prefill worker available, embeddings directly from encode worker."
        )
        # Extract the pre-allocated, reusable image embeddings tensor and its
        # descriptor. Doing this avoids unnecessary memory de/registration with NIXL.
        embeddings, descriptor = self._embeddings_descriptor

        async with self._embeddings_lock:
            with self._connector.create_writable(descriptor) as writable:
                # Extract serialized metadata about the operation from the
                # writable operation, and use it to create a new EncodeRequest.
                encode_request = EncodeRequest(
                    request_id=request.request_id,
                    image_url=request.image_url,
                    serialized_request=writable.to_serialized(),
                )
                logger.debug(f"Encode request: {encode_request.model_dump_json()}")
                encode_generator = await self.encode_worker_client.round_robin(
                    encode_request.model_dump_json()
                )

                async for encode_response in encode_generator:
                    encode_output = EncodeResponse.model_validate_json(
                        encode_response.data()
                    )
                    logger.info(
                        f"Received response: {{ id: {encode_output.request_id} }}"
                    )

                # Wait for the write operation to complete. This should be a
                # no-op, since the encode worker has already responded.
                await writable.wait_for_completion()
            # `embeddings` now holds this request's image embeddings. Give the
            # engine a private copy so the shared buffer can be reused as soon
            # as the lock is released.
            image_embeddings = embeddings.clone()

        remote_prefill_params = None
        logger.debug(
            f"Prefilling locally for request {{ id: {request.request_id} }} with length {len(request.engine_prompt['prompt_token_ids'])}"
        )
        prompt_ids = request.engine_prompt["prompt_token_ids"]

        logger.debug(
            "Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
        )
        # Aggregated serving: the image embeddings reach the local engine as
        # multi-modal data alongside the unmodified prompt tokens.
        multi_modal_data = {"image": image_embeddings}

        return prompt_ids, multi_modal_data, remote_prefill_params

    async def remote_prefill(self, request: vLLMMultimodalRequest) -> tuple:
        """
        Handles remote prefill in disaggregated serving mode.

        Creates remote prefill parameters and inserts dummy tokens for proper
        memory allocation. No direct encode worker interaction is required.

        Args:
            request: The multimodal request containing image URL and prompt data

        Returns:
            Tuple of (prompt_ids, multi_modal_data, remote_prefill_params)

        Raises:
            ValueError: if the prompt does not contain the image token.
        """
        logger.debug(
            f"Disaggregated: request {{ id: {request.request_id} }}"
            " prefill worker will populate the decode model's key-value cache ahead of time;"
            " no direct encode worker interaction required."
        )
        prompt_token_ids = request.engine_prompt["prompt_token_ids"]

        if self.disaggregated_router is not None:
            async with PrefillQueue.get_instance(
                nats_server=self._prefill_queue_nats_server,
                stream_name=self._prefill_queue_stream_name,
            ) as prefill_queue:
                prefill_queue_size = await prefill_queue.get_queue_size()
            disagg_router_decision = await self.disaggregated_router.prefill_remote(
                len(prompt_token_ids),
                request.prefix_hit_rate,
                prefill_queue_size,
            )
        else:
            # always prefill remotely if no disaggregated router is provided
            disagg_router_decision = True

        if self.do_remote_prefill and disagg_router_decision:
            logger.debug(
                f"Prefilling remotely for request {{ id: {request.request_id} }} with length {len(request.engine_prompt['prompt_token_ids'])}"
            )
            remote_prefill_params = RemotePrefillParams(
                is_remote_prefill=True,
                remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
                # Pass the image url as part of the RemotePrefillParams, which
                # will be passed to the prefill worker via RemotePrefillRequest.
                multimodal_data_source={
                    "image_url": request.image_url,
                },
            )
        else:
            # NOTE(review): on this path the local engine gets no image data.
            # multi_modal_data is None below, and no encode worker client exists
            # in disaggregated mode, so the image content is presumably lost.
            # Confirm whether conditional_disagg is meant for multimodal requests.
            remote_prefill_params = None
            logger.debug(
                f"Prefilling locally for request {{ id: {request.request_id} }} with length {len(request.engine_prompt['prompt_token_ids'])}"
            )

        # The decode worker pre-allocates KV-cache memory based on the prompt
        # token length, and the prefill worker then transfers its KV cache into
        # that memory. As a workaround, insert placeholder dummy tokens sized by
        # the vision embedding count so the allocation has the correct size.
        # The prompt looks like: "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:".
        # The "<image>" token is already in the prompt, so only
        # (embedding_size - 1) dummy tokens are inserted after it.
        DUMMY_TOKEN_ID = 0
        try:
            image_token_index = prompt_token_ids.index(self._image_token_id)
        except ValueError:
            raise ValueError(
                f"Prompt for request {request.request_id} does not contain the "
                f"image token (id {self._image_token_id})"
            ) from None
        dummy_token_index = image_token_index + 1
        prompt_ids = (
            prompt_token_ids[:dummy_token_index]
            + [DUMMY_TOKEN_ID] * (self.embedding_size - 1)
            + prompt_token_ids[dummy_token_index:]
        )
        logger.debug(
            "Disaggregated: no embedding data required as prefill will have provided key-value cache updates via encode worker."
        )
        # Disaggregated serving: the prefill worker supplies the KV cache, so no
        # multi-modal data is passed to the local engine.
        multi_modal_data = None

        return prompt_ids, multi_modal_data, remote_prefill_params