# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import os

if "PYTHONHASHSEED" not in os.environ:
    os.environ["PYTHONHASHSEED"] = "0"

import argparse
import asyncio
import copy
import logging
import signal
import sys
from typing import Tuple

import torch
import uvloop
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.inputs.data import TokensPrompt
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.engine.async_llm import AsyncLLM

import dynamo.nixl_connect as connect
from dynamo.llm import KvEventPublisher
from dynamo.runtime import DistributedRuntime, Endpoint, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging

# Make modules in the parent directory of this file (`publisher`, `utils`)
# importable; must run before the imports below.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from publisher import StatLoggerFactory
from utils.args import (
    Config,
    base_parse_args,
    configure_ports,
    overwrite_args,
    parse_endpoint,
)
from utils.image_loader import ImageLoader
from utils.model import construct_mm_data
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest

# Apply Dynamo's logging configuration first, then obtain this module's logger.
configure_dynamo_logging()
logger = logging.getLogger(name=__name__)


class VllmBaseWorker:
    """Shared plumbing for the multimodal vLLM workers in this file.

    Handles CLI/engine argument parsing and creates the vLLM ``AsyncLLM``
    engine. It also attaches the Dynamo stats logger and KV-event publisher,
    both scoped to the worker's endpoint. Subclasses provide :meth:`generate`
    and may override :meth:`async_init` and :meth:`cleanup`.
    """

    @classmethod
    def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
        """Parse the worker-specific flags together with the vLLM engine args.

        Endpoint defaults depend on ``--worker-type`` and on the
        ``DYN_NAMESPACE`` environment variable (``"dynamo"`` when unset).

        Returns:
            ``(args, config)`` as returned by ``base_parse_args``.
        """
        parser = FlexibleArgumentParser(
            description="vLLM based worker for Dynamo multimodal LLM serving."
        )
        parser.add_argument(
            "--endpoint",
            type=str,
            help="Dynamo endpoint string in 'dyn://namespace.component.endpoint' format.  Default value will vary based on the worker type, see --worker-type for details.",
        )
        parser.add_argument(
            "--downstream-endpoint",
            type=str,
            help="The endpoint string of the downstream LLM in 'dyn://namespace.component.endpoint' format. Default value will vary based on the worker type, see --worker-type for details.",
        )
        parser.add_argument(
            "--worker-type",
            type=str,
            choices=["prefill", "decode", "encode_prefill"],
            required=True,
            help="Specify the type of worker. Must be one of: 'prefill', 'decode', 'encode_prefill'",
        )
        parser.add_argument(
            "--enable-disagg",
            action="store_true",
            help="Enable disaggregated mode, where prefill and decode are handled by separate workers."
            " If not set, the '*prefill' worker type will handle both prefill and decode.",
        )

        def endpoint_overwrite(args):
            """Fill in ``--endpoint`` / ``--downstream-endpoint`` when they were
            not given (or were empty) on the command line."""
            dyn_namespace = os.environ.get("DYN_NAMESPACE", "dynamo")

            # Default endpoint served by this worker, keyed by worker type.
            if args.worker_type == "prefill":
                args.endpoint = args.endpoint or f"dyn://{dyn_namespace}.llm.generate"
            elif args.worker_type == "decode":
                args.endpoint = (
                    args.endpoint or f"dyn://{dyn_namespace}.decoder.generate"
                )
            elif args.worker_type == "encode_prefill":
                args.endpoint = (
                    args.endpoint or f"dyn://{dyn_namespace}.encoder.generate"
                )

            # In disaggregated mode decoding is forwarded to this endpoint.
            if args.enable_disagg:
                args.downstream_endpoint = (
                    args.downstream_endpoint
                    or f"dyn://{dyn_namespace}.decoder.generate"
                )

            return args

        args, config = base_parse_args(parser, endpoint_overwrite)

        return args, config

    def __init__(
        self,
        args: argparse.Namespace,
        endpoint: Endpoint,
        config: Config,
    ):
        """Record routing settings from ``args`` and build the vLLM engine.

        Args:
            args: Parsed worker arguments (see :meth:`parse_args`).
            endpoint: Dynamo endpoint object this worker serves. It is passed
                to the stats logger and KV-event publisher. Note that
                ``self.endpoint`` holds the endpoint *string* from ``args``.
            config: Parsed configuration carrying the vLLM ``engine_args``.
        """
        self.enable_disagg = args.enable_disagg
        self.endpoint = args.endpoint
        self.downstream_endpoint = args.downstream_endpoint
        self.engine_args = config.engine_args
        self.config = config
        self.setup_vllm_engine(endpoint)

    async def async_init(self, runtime: DistributedRuntime):
        """Async setup hook run after construction; a no-op by default."""
        pass

    def setup_vllm_engine(self, endpoint: Endpoint):
        """Initialize the vLLM engine.
        This method sets up the vLLM engine client, and configures the dynamo-aware KV
        event publisher and metrics stats logger based on endpoint.
        """

        os.environ["VLLM_NO_USAGE_STATS"] = "1"  # Avoid internal HTTP requests
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        # Load default sampling params from `generation_config.json`
        self.default_sampling_params = (
            self.engine_args.create_model_config().get_diff_sampling_param()
        )

        # Taken from build_async_engine_client_from_engine_args()
        usage_context = UsageContext.OPENAI_API_SERVER
        vllm_config = self.engine_args.create_engine_config(usage_context=usage_context)

        # Create vLLM engine with metrics logger and KV event publisher attached
        self.stats_logger = StatLoggerFactory(
            endpoint=endpoint,
            dp_rank=self.engine_args.data_parallel_rank or 0,
        )
        self.engine_client = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            stat_loggers=[self.stats_logger],
            enable_log_requests=self.engine_args.enable_log_requests,
            disable_log_stats=self.engine_args.disable_log_stats,
        )

        # TODO Hack to get data, move this to registering in ETCD
        self.stats_logger.set_num_gpu_blocks_all(
            vllm_config.cache_config.num_gpu_blocks
        )
        self.stats_logger.init_publish()

        # TODO: We start off with a valid endpoint, then we increment it by dp_rank
        # May no longer be valid. Lets remove the increment behavior from vLLM and here
        zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
            self.engine_args.kv_events_config.endpoint,
            data_parallel_rank=self.engine_args.data_parallel_rank or 0,
        ).replace("*", "127.0.0.1")

        self.kv_publisher = KvEventPublisher(
            endpoint=endpoint,
            kv_block_size=vllm_config.cache_config.block_size,
            zmq_endpoint=zmq_endpoint,
        )

        logger.info("Reading Events from %s", zmq_endpoint)

        logger.info("VllmWorker for %s has been initialized", self.engine_args.model)

    async def generate(self, request: vLLMMultimodalRequest):
        """Handle one request; must be implemented by subclasses."""
        raise NotImplementedError(
            "This method should be implemented in subclasses to handle the generation logic."
        )

    async def clear_kv_blocks(self, request=None):
        """Reset the engine's prefix cache and yield a single status dict.

        Failures are reported to the caller as ``{"status": "error", ...}``
        rather than raised, and are also logged locally with a traceback.
        """
        try:
            await self.engine_client.reset_prefix_cache()
        except Exception as e:
            logger.exception("Failed to clear KV cache")
            yield {"status": "error", "message": str(e)}
        else:
            yield {"status": "success", "message": "KV cache cleared"}

    def cleanup(self):
        """Override in subclasses if cleanup is needed."""
        pass


class VllmDecodeWorker(VllmBaseWorker):
    """Decode worker: runs generation on an already-tokenized prompt.

    It serves the default ``decoder.generate`` endpoint that the
    disaggregated prefill worker forwards to. No multimodal data is attached
    to the prompt here.
    """

    async def generate(self, request: vLLMMultimodalRequest):
        """Stream outputs for ``request`` as JSON-encoded ``MyRequestOutput``.

        ``request`` may be a ``vLLMMultimodalRequest``, a JSON string, or any
        other object accepted by ``model_validate``.
        """
        logger.debug(f"Got raw request: {request}")
        if isinstance(request, vLLMMultimodalRequest):
            req = request
        elif isinstance(request, str):
            req = vLLMMultimodalRequest.model_validate_json(request)
        else:
            req = vLLMMultimodalRequest.model_validate(request)
        logger.debug(f"Received decode request: {{ id: {req.request_id} }}.")

        # The decode side works on tokens only; no embeddings are passed.
        token_ids = req.engine_prompt["prompt_token_ids"]
        stream = self.engine_client.generate(
            prompt=TokensPrompt(prompt_token_ids=token_ids),
            sampling_params=req.sampling_params,
            request_id=req.request_id,
        )

        async for out in stream:
            transfer = out.kv_transfer_params
            logger.debug(f"Response kv_transfer_params: {transfer}")
            payload = MyRequestOutput(
                request_id=out.request_id,
                prompt=out.prompt,
                prompt_token_ids=out.prompt_token_ids,
                prompt_logprobs=out.prompt_logprobs,
                outputs=out.outputs,
                finished=out.finished,
                metrics=out.metrics,
                kv_transfer_params=transfer,
            )
            yield payload.model_dump_json()


class VllmPDWorker(VllmBaseWorker):
    """Prefill worker (or aggregated prefill+decode) for multimodal requests.

    Multimodal input arrives in one of two ways:

    * as embeddings read through the dynamo connector, when every URL field
      of ``request.multimodal_input`` is unset; or
    * as an image URL, which is loaded locally.

    With ``--enable-disagg`` only the prefill runs here, and decoding is
    forwarded to ``self.downstream_endpoint``.
    """

    async def async_init(self, runtime: DistributedRuntime):
        """Prepare everything :meth:`generate` needs.

        In disaggregated mode this connects to the downstream decode
        endpoint. It also chooses the embedding buffer dtype and device, and
        creates the connector and image loader.
        """
        logger.info("Startup started.")

        if self.enable_disagg:
            (
                parsed_namespace,
                parsed_component_name,
                parsed_endpoint_name,
            ) = parse_endpoint(self.downstream_endpoint)
            self.decode_worker_client = await runtime.endpoint(
                f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
            ).client()

        # Video models receive a uint8 buffer (later passed on as a numpy
        # video array); all other models receive float16 embeddings.
        if "video" in self.engine_args.model.lower():
            self.EMBEDDINGS_DTYPE = torch.uint8
        else:
            self.EMBEDDINGS_DTYPE = torch.float16

        self.EMBEDDINGS_DEVICE = "cpu"

        # NOTE(review): the parsed parts of our own endpoint were never used.
        # The call is kept in case parse_endpoint rejects malformed endpoints.
        # Confirm, and drop it if it does not.
        parse_endpoint(self.endpoint)

        # Create and initialize a dynamo connector for this worker.
        # We need it to move data between this worker and remote workers efficiently.
        self._connector = connect.Connector()

        self.image_loader = ImageLoader()

        logger.info("VllmPDWorker has been initialized")

    @staticmethod
    def _coerce_request(request) -> vLLMMultimodalRequest:
        """Return ``request`` as a ``vLLMMultimodalRequest``.

        A JSON string is parsed with ``model_validate_json``; any other
        non-model object goes through ``model_validate``.
        """
        if isinstance(request, vLLMMultimodalRequest):
            return request
        if isinstance(request, str):
            return vLLMMultimodalRequest.model_validate_json(request)
        return vLLMMultimodalRequest.model_validate(request)

    @staticmethod
    def _to_output_json(output) -> str:
        """Re-wrap an engine/decode output as ``MyRequestOutput`` JSON."""
        return MyRequestOutput(
            request_id=output.request_id,
            prompt=output.prompt,
            prompt_token_ids=output.prompt_token_ids,
            prompt_logprobs=output.prompt_logprobs,
            outputs=output.outputs,
            finished=output.finished,
            metrics=output.metrics,
            kv_transfer_params=output.kv_transfer_params,
        ).model_dump_json()

    async def _read_embeddings(self, request: vLLMMultimodalRequest):
        """Read the remote embeddings into a local buffer and build vLLM
        multimodal data from them."""
        # Allocate a destination buffer and describe it for the connector.
        # Assumes request.embeddings_shape matches what the sender registered
        # (TODO confirm).
        embeddings = torch.empty(
            request.embeddings_shape,
            dtype=self.EMBEDDINGS_DTYPE,
            device=self.EMBEDDINGS_DEVICE,
        )
        descriptor = connect.Descriptor(embeddings)

        read_op = await self._connector.begin_read(
            request.serialized_request, descriptor
        )
        await read_op.wait_for_completion()

        model_name = self.engine_args.model
        lowered = model_name.lower()
        if "video" in lowered:
            return construct_mm_data(
                model_name,
                self.EMBEDDINGS_DTYPE,
                video_numpy=embeddings.numpy(),
            )
        if "audio" in lowered:
            return construct_mm_data(
                model_name,
                self.EMBEDDINGS_DTYPE,
                audio_embeds=embeddings,
            )
        return construct_mm_data(
            model_name,
            self.EMBEDDINGS_DTYPE,
            image_embeds=embeddings,
            image_grid_thw=request.image_grid_thw,
        )

    async def _load_multimodal_data(self, request: vLLMMultimodalRequest):
        """Build the ``multi_modal_data`` for the prompt of ``request``.

        Raises:
            ValueError: if a video or audio URL is given without an image URL.
                Only image URLs can be loaded here; video and audio must
                arrive as embeddings.
        """
        mm_input = request.multimodal_input
        if (
            mm_input.image_url is None
            and mm_input.video_url is None
            and mm_input.audio_url is None
        ):
            return await self._read_embeddings(request)

        if mm_input.image_url is None:
            # Previously this fell through to load_image(None).
            raise ValueError(
                "PD worker can only load image URLs directly; video/audio "
                "inputs must be provided as embeddings via the connector"
            )

        # Use PIL image instead of image embeddings
        return {"image": await self.image_loader.load_image(mm_input.image_url)}

    async def generate(self, request: vLLMMultimodalRequest):
        """Run prefill (plus decode, unless disaggregated) for ``request``.

        Yields JSON-encoded ``MyRequestOutput`` chunks. In disaggregated mode
        the chunks come from the downstream decode worker.
        """
        logger.debug("Got raw request: %s", request)
        request = self._coerce_request(request)
        logger.debug("Received PD request: { id: %s }.", request.request_id)

        multi_modal_data = await self._load_multimodal_data(request)

        # The multimodal payload now lives in multi_modal_data. Drop the raw
        # inputs so they are not copied into the prefill/decode requests.
        request.multimodal_input.image_url = None
        request.multimodal_input.video_url = None
        request.multimodal_input.audio_url = None
        request.serialized_request = None

        pd_request = copy.deepcopy(request)
        if self.enable_disagg:
            # Prefill only: produce a single token here and mark the request
            # for remote decode through kv_transfer_params.
            extra_args = pd_request.sampling_params.extra_args or {}
            extra_args["kv_transfer_params"] = {
                "do_remote_decode": True,
            }
            pd_request.sampling_params.extra_args = extra_args
            pd_request.sampling_params.max_tokens = 1
            pd_request.sampling_params.min_tokens = 1

            logger.debug("Prefill request: %s", pd_request)

        gen = self.engine_client.generate(
            prompt=TokensPrompt(
                prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"],
                multi_modal_data=multi_modal_data,
            ),
            sampling_params=pd_request.sampling_params,
            request_id=pd_request.request_id,
        )

        if not self.enable_disagg:
            async for response in gen:
                logger.debug(
                    "Response kv_transfer_params: %s", response.kv_transfer_params
                )
                yield self._to_output_json(response)
            return

        decode_request = copy.deepcopy(request)
        async for prefill_response in gen:
            # Use the prompt token ids from the prefill response, which has
            # the image template filled in, so the decode worker fetches the
            # correct number of KV blocks.
            decode_request.engine_prompt[
                "prompt_token_ids"
            ] = prefill_response.prompt_token_ids
            logger.debug(
                "Prefill response kv_transfer_params: %s",
                prefill_response.kv_transfer_params,
            )
            extra_args = decode_request.sampling_params.extra_args or {}
            extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params
            extra_args.pop("serialized_request", None)
            decode_request.sampling_params.extra_args = extra_args
            logger.debug("Decode request: %s", decode_request)

            decode_stream = await self.decode_worker_client.round_robin(
                decode_request.model_dump_json()
            )
            async for decode_response in decode_stream:
                output = MyRequestOutput.model_validate_json(decode_response.data())
                yield self._to_output_json(output)


async def graceful_shutdown(runtime):
    """
    By calling `runtime.shutdown()`, the endpoints will immediately be unavailable.
    However, in-flight requests will still be processed until they are finished.
    After all in-flight requests are finished, the `serve_endpoint` functions will return
    and the engine will be shutdown by Python's garbage collector.
    """
    # Use the module logger. The module-level logging.info() targets the root
    # logger and calls basicConfig() if the root has no handlers.
    logger.info("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()
    logger.info("DistributedRuntime shutdown complete")


@dynamo_worker()
async def worker(runtime: DistributedRuntime):
    """Process entry point.

    Installs signal handlers for graceful shutdown, parses arguments, applies
    the vLLM config overrides, then hands off to :func:`init`.
    """
    loop = asyncio.get_running_loop()

    # The event loop only keeps weak references to tasks. Hold strong
    # references so a shutdown task cannot be garbage-collected mid-run.
    shutdown_tasks = set()

    def signal_handler():
        task = asyncio.create_task(graceful_shutdown(runtime))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    logger.info("Signal handlers set up for graceful shutdown")

    # worker setup
    args, config = VllmBaseWorker.parse_args()

    # vLLM config overwrites
    configure_ports(config)
    overwrite_args(config)
    await init(runtime, args, config)


async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
    """Instantiate the worker for ``args.worker_type`` and serve it.

    ``generate`` is served on the configured endpoint and ``clear_kv_blocks``
    on a sibling endpoint of the same component. The handler's ``cleanup()``
    runs once serving ends, whether it succeeded or failed.

    Raises:
        ValueError: if ``args.worker_type`` is not a supported worker type.
    """
    generate_endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.{config.endpoint}"
    )
    clear_endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.clear_kv_blocks"
    )

    handler: VllmBaseWorker
    if args.worker_type in ("prefill", "encode_prefill"):
        handler = VllmPDWorker(args, generate_endpoint, config)
    elif args.worker_type == "decode":
        handler = VllmDecodeWorker(args, generate_endpoint, config)
    else:
        # argparse `choices` should prevent this. Fail clearly rather than
        # with an UnboundLocalError on `handler` below.
        raise ValueError(f"Unsupported worker type: {args.worker_type!r}")

    await handler.async_init(runtime)

    logger.info("Starting to serve the %s endpoint...", args.endpoint)

    metrics_labels = [("model", config.model)]

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate, metrics_labels=metrics_labels
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks, metrics_labels=metrics_labels
            ),
        )
    except Exception as e:
        # Keep the traceback; the exception is re-raised to the caller.
        logger.exception("Failed to serve endpoints: %s", e)
        raise
    finally:
        handler.cleanup()


if __name__ == "__main__":
    uvloop.install()
    asyncio.run(worker())