main.py 29.3 KB
Newer Older
Alec's avatar
Alec committed
1
2
3
4
5
6
7
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import logging
import os
import signal
8
import tempfile
9
from typing import Optional
Alec's avatar
Alec committed
10
11

import uvloop
12
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
Alec's avatar
Alec committed
13
14
15
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
16
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
Alec's avatar
Alec committed
17

18
from dynamo.common.config_dump import dump_config
19
from dynamo.common.utils.endpoint_types import parse_endpoint_types
20
from dynamo.common.utils.prometheus import register_engine_metrics_callback
Alec's avatar
Alec committed
21
from dynamo.llm import (
22
    ModelInput,
23
    ModelRuntimeConfig,
Alec's avatar
Alec committed
24
25
26
    ModelType,
    ZmqKvEventPublisher,
    ZmqKvEventPublisherConfig,
27
    fetch_llm,
Alec's avatar
Alec committed
28
29
    register_llm,
)
30
from dynamo.runtime import DistributedRuntime
Alec's avatar
Alec committed
31
from dynamo.runtime.logging import configure_dynamo_logging
32
33
from dynamo.vllm.multimodal_handlers import (
    EncodeWorkerHandler,
Ayush Agarwal's avatar
Ayush Agarwal committed
34
    MultimodalDecodeWorkerHandler,
35
36
37
    MultimodalPDWorkerHandler,
    ProcessorHandler,
)
Alec's avatar
Alec committed
38

39
from .args import Config, overwrite_args, parse_args
Alec's avatar
Alec committed
40
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
41
from .health_check import VllmHealthCheckPayload, VllmPrefillHealthCheckPayload
Alec's avatar
Alec committed
42
43
from .publisher import StatLoggerFactory

Alec's avatar
Alec committed
44
45
46
47
48
49
# Run dynamo's logging configuration helper (imported from dynamo.runtime.logging)
# at import time, before this module's logger is created.
configure_dynamo_logging()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


async def graceful_shutdown(runtime):
    """
    Shut down the dynamo distributed runtime.

    Endpoints are invalidated immediately, so no new requests are accepted.
    Endpoints served with graceful_shutdown=True wait until every in-flight
    request has finished; endpoints served with graceful_shutdown=False return
    immediately.

    Args:
        runtime: Runtime object whose ``shutdown()`` method is invoked.
    """
    # Announce the shutdown before triggering it so the log shows the ordering.
    logging.info("Received shutdown signal, shutting down DistributedRuntime")

    runtime.shutdown()

    logging.info("DistributedRuntime shutdown complete")


60
async def worker():
    """
    Entry coroutine for a vLLM worker process.

    Parses the command-line configuration, creates the DistributedRuntime on
    the running event loop, installs SIGTERM/SIGINT handlers that schedule
    graceful_shutdown, fetches the model with fetch_llm when config.model is
    not an existing local path, and then awaits the init_* coroutine selected
    by the configuration flags.
    """
    config = parse_args()

    loop = asyncio.get_running_loop()
    runtime = DistributedRuntime(loop, config.store_kv, config.request_plane)

    overwrite_args(config)

    # The event loop only keeps weak references to tasks. Hold a strong
    # reference to each shutdown task until it finishes so it cannot be
    # garbage-collected before runtime.shutdown() has actually run.
    shutdown_tasks = set()

    # Set up signal handler for graceful shutdown
    def signal_handler():
        task = asyncio.create_task(graceful_shutdown(runtime))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    logger.debug("Signal handlers set up for graceful shutdown")

    dump_config(config.dump_config_to, config)

    # Download the model if necessary.
    # register_llm would do this for us, but we want it on disk before we start vllm.
    # Ensure the original HF name (e.g. "Qwen/Qwen3-0.6B") is used as the served_model_name.
    if not config.served_model_name:
        config.served_model_name = config.engine_args.served_model_name = config.model
    if not os.path.exists(config.model):
        config.model = config.engine_args.model = await fetch_llm(config.model)

    # Route to appropriate initialization based on config flags
    if config.multimodal_processor:
        await init_multimodal_processor(runtime, config)
        logger.debug("init_multimodal_processor completed")
    elif config.multimodal_encode_worker:
        await init_multimodal_encode_worker(runtime, config)
        logger.debug("init_multimodal_encode_worker completed")
    elif (
        config.multimodal_worker
        or config.multimodal_decode_worker
        or config.multimodal_encode_prefill_worker
    ):
        await init_multimodal_worker(runtime, config)
        logger.debug("init_multimodal_worker completed")
    elif config.is_prefill_worker:
        await init_prefill(runtime, config)
        logger.debug("init_prefill completed")
    else:
        await init(runtime, config)
        logger.debug("init completed")

    logger.debug("Worker function completed, exiting...")
Alec's avatar
Alec committed
109
110


111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def setup_metrics_collection(config: Config, generate_endpoint, logger):
    """Register Prometheus metric callbacks (vLLM / LMCache) for an endpoint.

    Nothing is registered unless ``config.engine_args.disable_log_stats`` is
    exactly ``False``.

    When PROMETHEUS_MULTIPROC_DIR is set (multiprocess mode), metrics live in
    two places:
      1. In-memory: Metric objects in the global REGISTRY
      2. On-disk: Metric values in .db files under PROMETHEUS_MULTIPROC_DIR

    MultiProcessCollector reads the .db files, but attaching it to REGISTRY can
    fail with "Duplicated timeseries" if PROMETHEUS_MULTIPROC_DIR was set before
    the process started (K8s deployments), because the metrics are already in
    REGISTRY.

    Strategy: first try to attach MultiProcessCollector to REGISTRY. If that
    raises ValueError, collect the .db files through a separate registry and
    register callbacks on both registries so all metrics (vllm, lmcache,
    dynamo_component) are collected.

    Args:
        config: Worker configuration; only engine_args.disable_log_stats is read.
        generate_endpoint: Endpoint the metric callbacks are registered on.
        logger: Logger used for the debug messages emitted here.
    """
    if config.engine_args.disable_log_stats is not False:
        return

    def _register(registry, prefixes):
        # Thin wrapper so each call site only states what differs.
        register_engine_metrics_callback(
            endpoint=generate_endpoint,
            registry=registry,
            metric_prefix_filters=prefixes,
        )

    if not os.environ.get("PROMETHEUS_MULTIPROC_DIR"):
        # No multiprocess mode
        _register(REGISTRY, ["vllm:", "lmcache:"])
        return

    try:
        # MultiProcessCollector reads metrics from .db files in PROMETHEUS_MULTIPROC_DIR
        # Adding it to REGISTRY allows collecting both in-memory and .db file metrics
        multiprocess.MultiProcessCollector(REGISTRY)
        logger.debug("Added MultiProcessCollector to global REGISTRY")
        _register(REGISTRY, ["vllm:", "lmcache:"])
    except ValueError as e:
        # Conflict: metrics already in REGISTRY, and MultiProcessCollector tries
        # to add the same metrics from the .db files. Fall back to a separate
        # registry that ONLY reads the .db files (no in-memory conflicts).
        logger.debug(
            f"Could not add MultiProcessCollector to REGISTRY ({e}), using separate registry"
        )
        multiproc_registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(multiproc_registry)

        # Global REGISTRY has in-memory metrics (vllm, dynamo_component)
        _register(REGISTRY, ["vllm:", "dynamo_component:"])
        # Multiproc registry has .db file metrics (lmcache, possibly vllm duplicates)
        _register(multiproc_registry, ["vllm:", "lmcache:"])


Yan Ru Pei's avatar
Yan Ru Pei committed
169
170
171
172
173
def setup_kv_event_publisher(
    config: Config,
    component,
    generate_endpoint,
    vllm_config,
    consolidator_enabled: bool = False,
    consolidator_port: Optional[int] = 5558,
) -> Optional[list[ZmqKvEventPublisher]]:
    """
    Set up KV event publishers for prefix caching if enabled.
    Creates one publisher per dp_rank since each dp_rank publishes to a different port.

    Args:
        config: Worker configuration
        component: Component for runtime integration
        generate_endpoint: Endpoint for worker ID
        vllm_config: vLLM configuration
        consolidator_enabled: If True, subscribe to the kv event consolidator's ZMQ endpoint
        consolidator_port: Port where kv event consolidator publishes (default: 5558)

    Returns:
        List of ZmqKvEventPublisher instances (one per dp_rank), or None when
        prefix caching is disabled, the worker is a decode worker, no
        kv_events_config is set, or no publisher was created.

    Raises:
        ValueError: If consolidator_enabled is True but consolidator_port is None.
    """
    if not config.engine_args.enable_prefix_caching:
        return None

    # Skip KV event publishing for decode workers
    if config.is_decode_worker:
        logger.info("Skipping KV event publisher setup for decode worker")
        return None

    if config.engine_args.kv_events_config is None:
        return None

    # Fail fast instead of building the unusable endpoint "tcp://127.0.0.1:None".
    if consolidator_enabled and consolidator_port is None:
        raise ValueError(
            "consolidator_port must be set when consolidator_enabled is True"
        )

    # Get data_parallel_size to create publishers for all dp_ranks
    data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
    kv_publishers = []

    for dp_rank in range(data_parallel_size):
        if consolidator_enabled:
            # TODO: Use different port for each dp_rank once KVBM supports DP
            zmq_endpoint = f"tcp://127.0.0.1:{consolidator_port}"
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to consolidator at {zmq_endpoint}"
            )
        else:
            # Each dp_rank publishes to a different port
            zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
                config.engine_args.kv_events_config.endpoint,
                data_parallel_rank=dp_rank,
            ).replace("*", "127.0.0.1")
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to vLLM at {zmq_endpoint}"
            )

        zmq_config = ZmqKvEventPublisherConfig(
            worker_id=generate_endpoint.connection_id(),
            kv_block_size=vllm_config.cache_config.block_size,
            zmq_endpoint=zmq_endpoint,
        )
        kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
        kv_publishers.append(kv_publisher)

        logger.info(
            f"Worker reading KV events for dp_rank={dp_rank} from {zmq_endpoint}"
        )

    # An empty list (data_parallel_size of 0) is reported as None.
    return kv_publishers if kv_publishers else None
Yan Ru Pei's avatar
Yan Ru Pei committed
236
237


Alec's avatar
Alec committed
238
def setup_vllm_engine(config, stat_logger=None):
    """
    Prepare the process environment and build the vLLM AsyncLLM engine client.

    Args:
        config: Worker configuration; engine_args, served_model_name and
            has_connector() are used here.
        stat_logger: Optional stat logger passed to the engine; when falsy the
            engine receives an empty stat_loggers list.

    Returns:
        Tuple ``(engine_client, vllm_config, default_sampling_params,
        prometheus_temp_dir)``. ``prometheus_temp_dir`` is the
        TemporaryDirectory created here, or None when PROMETHEUS_MULTIPROC_DIR
        was already present in the environment.
    """
    # vLLM v0.11.0 bug: vllm/v1/metrics/prometheus.py:79 passes a TemporaryDirectory
    # object instead of its .name string, causing a false error on exit. Set
    # PROMETHEUS_MULTIPROC_DIR ourselves to avoid this and handle cleanup properly.
    prometheus_temp_dir = None
    if os.environ.get("PROMETHEUS_MULTIPROC_DIR") is None:
        prometheus_temp_dir = tempfile.TemporaryDirectory(prefix="vllm_prometheus_")
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_temp_dir.name
        logger.debug(
            f"Created PROMETHEUS_MULTIPROC_DIR at: {os.environ['PROMETHEUS_MULTIPROC_DIR']}"
        )

    # Call vLLM's own helper to set up multiprocess prometheus.
    setup_multiprocess_prometheus()
    logger.debug(
        f"Prometheus multiproc dir set to: {os.environ.get('PROMETHEUS_MULTIPROC_DIR')}"
    )

    os.environ.update(
        {
            "VLLM_NO_USAGE_STATS": "1",  # Avoid internal HTTP requests
            "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
        }
    )

    engine_args = config.engine_args

    if engine_args.enable_lora:
        # Only fill in LoRA defaults; values already in the environment win.
        os.environ.setdefault("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "True")
        os.environ.setdefault("VLLM_LORA_MODULES_LOADING_TIMEOUT", "600")

    # Load default sampling params from `generation_config.json`
    model_config = engine_args.create_model_config()
    default_sampling_params = model_config.get_diff_sampling_param()

    # Taken from build_async_engine_client_from_engine_args()
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    # Set up consolidator endpoints if KVBM is enabled
    consolidator_endpoints = None
    if config.has_connector("kvbm"):
        try:
            from kvbm.vllm_integration.consolidator_config import (
                get_consolidator_endpoints,
            )

            consolidator_endpoints = get_consolidator_endpoints(vllm_config)
        except Exception as e:
            # Best effort: keep going without consolidation if kvbm is unusable.
            logger.warning(
                f"KVBM connector is enabled but failed to get consolidator endpoints: {e}. "
                "Continuing without KV event consolidation. "
                "Ensure 'kvbm' package is installed if this feature is needed."
            )
    # Stored on vllm_config so the init_* callers can read it back later.
    vllm_config.consolidator_endpoints = consolidator_endpoints

    stat_loggers = [stat_logger] if stat_logger else []

    engine_client = AsyncLLM.from_vllm_config(
        vllm_config=vllm_config,
        usage_context=usage_context,
        stat_loggers=stat_loggers,
        disable_log_requests=engine_args.disable_log_requests,
        disable_log_stats=engine_args.disable_log_stats,
    )

    logger.info(f"VllmWorker for {config.served_model_name} has been initialized")

    return engine_client, vllm_config, default_sampling_params, prometheus_temp_dir
Alec's avatar
Alec committed
306
307


308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
async def register_vllm_model(
    model_input: ModelInput,
    model_type: ModelType,
    generate_endpoint,
    config: Config,
    engine_client: AsyncLLM,
    vllm_config,
    migration_limit: int,
):
    """
    Register a vLLM model together with its runtime configuration.

    Builds a ModelRuntimeConfig from the engine's cache/scheduler values and
    passes it to register_llm.

    Args:
        model_input: Input type for the model (e.g., ModelInput.Tokens)
        model_type: Type of model (e.g., ModelType.Chat, ModelType.Prefill)
        generate_endpoint: Endpoint to register
        config: Configuration object
        engine_client: vLLM engine client
        vllm_config: vLLM configuration
        migration_limit: Migration limit for the model
    """
    runtime_config = ModelRuntimeConfig()

    # Get runtime configuration from vLLM engine
    logging.info(
        f"Getting engine runtime configuration metadata from vLLM engine for {model_type}..."
    )
    engine_info = get_engine_cache_info(engine_client)
    runtime_config.total_kv_blocks = engine_info["num_gpu_blocks"]
    # These two fields share their name with the key in engine_info.
    for field_name in ("max_num_seqs", "max_num_batched_tokens"):
        setattr(runtime_config, field_name, engine_info[field_name])

    # Tool/reasoning parsers are only attached for non-prefill models.
    if model_type != ModelType.Prefill:
        runtime_config.tool_call_parser = config.tool_call_parser
        runtime_config.reasoning_parser = config.reasoning_parser

    # data_parallel_size comes from vllm_config and defaults to 1 when absent.
    runtime_config.data_parallel_size = getattr(
        vllm_config.parallel_config, "data_parallel_size", 1
    )

    await register_llm(
        model_input,
        model_type,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
        migration_limit=migration_limit,
        runtime_config=runtime_config,
        custom_template_path=config.custom_jinja_template,
    )


Alec's avatar
Alec committed
362
363
364
365
async def init_prefill(runtime: DistributedRuntime, config: Config):
    """
    Instantiate the vLLM engine for a prefill worker and serve its endpoints.

    Creates the generate and clear_kv_blocks endpoints, builds the engine and a
    PrefillWorkerHandler, wires up KV event publishers and metrics collection,
    registers the model as ModelType.Prefill (only when data_parallel_rank is 0
    or None), and then serves both endpoints until they return.
    handler.cleanup() always runs once serving ends.

    Args:
        runtime: Distributed runtime used to resolve the namespace/component.
        config: Worker configuration.

    Raises:
        Exception: Any error raised while serving is logged and re-raised.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")

    (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
    ) = setup_vllm_engine(config)

    handler = PrefillWorkerHandler(
        runtime,
        component,
        engine_client,
        default_sampling_params,
        getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
    consolidator_enabled = False
    consolidator_port = None

    if (
        hasattr(vllm_config, "consolidator_endpoints")
        and vllm_config.consolidator_endpoints
    ):
        # Extract connect endpoint (third element) for clients to subscribe
        # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint)
        consolidator_output_endpoint = vllm_config.consolidator_endpoints[2]
        consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
        consolidator_enabled = True

    # Set up KV event publishers for prefix caching if enabled (one per dp_rank)
    # If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
    kv_publishers = setup_kv_event_publisher(
        config,
        component,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    setup_metrics_collection(config, generate_endpoint, logger)

    # Register prefill model with ModelType.Prefill
    if not config.engine_args.data_parallel_rank:  # if rank is 0 or None then register
        await register_vllm_model(
            ModelInput.Tokens,
            ModelType.Prefill,
            generate_endpoint,
            config,
            engine_client,
            vllm_config,
            migration_limit=0,  # Prefill doesn't support migration
        )

    health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict()

    # In practice config.served_model_name is always set, but mypy needs the "or" here.
    # Both endpoints use the same label so their metrics cannot diverge.
    metrics_labels = [("model", config.served_model_name or config.model)]

    try:
        logger.debug("Starting serve_endpoint for prefill worker")
        await asyncio.gather(
            # for prefill, we want to shutdown the engine after all prefill requests are finished because
            #     (temp reason): we don't support re-routing prefill requests
            #     (long-term reason): prefill engine should pull from a global queue so there is
            #                         only a few in-flight requests that can be quickly finished
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=metrics_labels,
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=metrics_labels,
            ),
        )
        logger.debug("serve_endpoint completed for prefill worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up prefill worker")
        handler.cleanup()


async def init(runtime: DistributedRuntime, config: Config):
    """
    Instantiate the vLLM engine for a decode (or aggregated) worker and serve it.

    Creates the generate, clear_kv_blocks and LoRA management endpoints, builds
    the engine with a StatLoggerFactory, wires up KV event publishers and
    metrics collection, registers the model (only when data_parallel_rank is 0
    or None), and serves all endpoints until they return. handler.cleanup()
    always runs once serving ends.

    Args:
        runtime: Distributed runtime used to resolve the namespace/component.
        config: Worker configuration.

    Raises:
        Exception: Any error raised while serving is logged and re-raised.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)
    (
        clear_endpoint,
        load_lora_endpoint,
        unload_lora_endpoint,
        list_loras_endpoint,
    ) = (
        component.endpoint(endpoint_name)
        for endpoint_name in (
            "clear_kv_blocks",
            "load_lora",
            "unload_lora",
            "list_loras",
        )
    )

    stat_logger_factory = StatLoggerFactory(
        component,
        config.engine_args.data_parallel_rank or 0,
        metrics_labels=[("model", config.served_model_name or config.model)],
    )
    engine_setup = setup_vllm_engine(config, stat_logger_factory)
    engine_client, vllm_config, default_sampling_params, prometheus_temp_dir = (
        engine_setup
    )

    # TODO Hack to get data, move this to registering in TBD
    stat_logger_factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks)
    stat_logger_factory.set_request_total_slots_all(
        vllm_config.scheduler_config.max_num_seqs
    )
    stat_logger_factory.init_publish()

    # max_model_len is None when vllm_config has no model_config / max_model_len.
    model_config = getattr(vllm_config, "model_config", None)
    max_model_len = getattr(model_config, "max_model_len", None)

    handler = DecodeWorkerHandler(
        runtime,
        component,
        engine_client,
        default_sampling_params,
        max_model_len,
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # The kv event consolidator counts as enabled when setup_vllm_engine left
    # non-empty endpoints on vllm_config.
    # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint);
    # clients subscribe to the connect endpoint (third element).
    endpoints = getattr(vllm_config, "consolidator_endpoints", None)
    if endpoints:
        consolidator_enabled = True
        consolidator_port = int(endpoints[2].split(":")[-1])
    else:
        consolidator_enabled = False
        consolidator_port = None

    # Set up KV event publisher for prefix caching if enabled
    # If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
    kv_publishers = setup_kv_event_publisher(
        config,
        component,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    setup_metrics_collection(config, generate_endpoint, logger)

    if not config.engine_args.data_parallel_rank:  # if rank is 0 or None then register
        # Parse endpoint types from --dyn-endpoint-types flag
        model_type = parse_endpoint_types(config.dyn_endpoint_types)
        logger.info(
            f"Registering model with endpoint types: {config.dyn_endpoint_types}"
        )

        # Warn if custom template provided but chat endpoint not enabled
        chat_enabled = "chat" in config.dyn_endpoint_types
        if config.custom_jinja_template and not chat_enabled:
            logger.warning(
                "Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
                "The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
            )

        await register_vllm_model(
            ModelInput.Tokens,
            model_type,
            generate_endpoint,
            config,
            engine_client,
            vllm_config,
            migration_limit=config.migration_limit,
        )

    health_check_payload = VllmHealthCheckPayload(engine_client).to_dict()

    # Every endpoint reports the same ("model", <name>) label; each call still
    # receives its own list.
    model_label = ("model", config.served_model_name or config.model)

    try:
        logger.debug("Starting serve_endpoint for decode worker")
        serving = [
            # for decode, we want to transfer the in-flight requests to other decode engines,
            # because waiting them to finish can take a long time for long OSLs
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=config.migration_limit <= 0,
                metrics_labels=[model_label],
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=[model_label],
            ),
            load_lora_endpoint.serve_endpoint(
                handler.load_lora,
                metrics_labels=[model_label],
            ),
            unload_lora_endpoint.serve_endpoint(
                handler.unload_lora,
                metrics_labels=[model_label],
            ),
            list_loras_endpoint.serve_endpoint(
                handler.list_loras,
                metrics_labels=[model_label],
            ),
        ]
        await asyncio.gather(*serving)
        logger.debug("serve_endpoint completed for decode worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up decode worker")
        # Cleanup background tasks
        handler.cleanup()


596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
def get_engine_cache_info(engine: AsyncLLM):
    """Read cache and scheduler settings from an [`AsyncLLM`] engine's vllm_config.

    Args:
        engine: Engine exposing ``vllm_config.cache_config`` and
            ``vllm_config.scheduler_config``.

    Returns:
        Dict with the keys ``num_gpu_blocks``, ``max_num_seqs`` and
        ``max_num_batched_tokens``.

    Raises:
        Exception: Any error hit while reading the values is logged and re-raised.
    """
    try:
        # Values are read directly from vllm_config instead of via collective_rpc.
        cache_cfg = engine.vllm_config.cache_config
        cache_values = {"num_gpu_blocks": cache_cfg.num_gpu_blocks}

        sched_cfg = engine.vllm_config.scheduler_config
        scheduler_values = {
            "max_num_seqs": sched_cfg.max_num_seqs,
            "max_num_batched_tokens": sched_cfg.max_num_batched_tokens,
        }

        logging.info(f"Cache config values: {cache_values}")
        logging.info(f"Scheduler config values: {scheduler_values}")

        # Merge both groups into one flat mapping for the caller.
        return {**cache_values, **scheduler_values}
    except Exception as e:
        logging.error(f"Failed to get configuration values from vLLM config: {e}")
        raise


622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
async def init_multimodal_processor(runtime: DistributedRuntime, config: Config):
    """Initialize the multimodal processor component and serve its endpoint.

    Connects to the "encoder" component's generate endpoint, waits for encoder
    instances, registers this endpoint as a chat model entrypoint, and serves
    the handler until it returns. handler.cleanup() always runs afterwards.

    Args:
        runtime: Distributed runtime used to resolve namespaces/components.
        config: Worker configuration.

    Raises:
        Exception: Any error raised while serving is logged and re-raised.
    """
    component = runtime.namespace(config.namespace).component(config.component)
    generate_endpoint = component.endpoint(config.endpoint)

    # Get encode worker client
    encoder_endpoint = (
        runtime.namespace(config.namespace).component("encoder").endpoint("generate")
    )
    encode_worker_client = await encoder_endpoint.client()

    # The prompt template comes from args (must be passed via environment or command line)
    handler = ProcessorHandler(
        config.engine_args,
        encode_worker_client,
        config.mm_prompt_template,
    )

    logger.info("Waiting for Encoder Worker Instances ...")
    await encode_worker_client.wait_for_instances()

    # Register the endpoint as entrypoint to a model
    await register_llm(
        ModelInput.Text,  # Custom processor is used and this type bypasses SDK processor
        ModelType.Chat,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info("Starting to serve the processor endpoint...")

    model_labels = [("model", config.model)]
    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate, metrics_labels=model_labels
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()


async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Config):
    """Initialize the multimodal encode worker component and serve its endpoint.

    Connects to the "backend" component's generate endpoint, initializes the
    EncodeWorkerHandler, waits for PD worker instances, and serves the handler
    until it returns. handler.cleanup() always runs afterwards.

    Args:
        runtime: Distributed runtime used to resolve namespaces/components.
        config: Worker configuration.

    Raises:
        Exception: Any error raised while serving is logged and re-raised.
    """
    component = runtime.namespace(config.namespace).component(config.component)
    generate_endpoint = component.endpoint(config.endpoint)

    # Get PD worker client
    # In multimodal mode, the PD worker always registers as "backend"
    # (even in disaggregated mode with prefill/decode split, we still connect to "backend")
    backend_endpoint = (
        runtime.namespace(config.namespace).component("backend").endpoint("generate")
    )
    pd_worker_client = await backend_endpoint.client()

    handler = EncodeWorkerHandler(config.engine_args, pd_worker_client)
    await handler.async_init(runtime)

    logger.info("Waiting for PD Worker Instances ...")
    await pd_worker_client.wait_for_instances()
    logger.info("Starting to serve the encode worker endpoint...")

    model_labels = [("model", config.model)]
    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate, metrics_labels=model_labels
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()


async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
    """
    Initialize the multimodal worker component and serve its endpoints.

    Supports two modes:
    1. --multimodal-worker: Receives embeddings from separate encoder
    2. --multimodal-encode-prefill-worker: Handles inline encoding (e.g., Llama 4)

    Both can operate in aggregated (P+D) or disaggregated (P→D) mode.

    When config.multimodal_decode_worker is set, a
    MultimodalDecodeWorkerHandler is used; otherwise a
    MultimodalPDWorkerHandler is used.

    Args:
        runtime: Distributed runtime used to resolve namespaces/components.
        config: Worker configuration.

    Raises:
        Exception: Any error raised while serving is logged and re-raised.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")

    engine_setup = setup_vllm_engine(config)
    engine_client, vllm_config, default_sampling_params, prometheus_temp_dir = (
        engine_setup
    )

    # Set up decode worker client for disaggregated mode
    decode_worker_client = None
    if config.is_prefill_worker:
        # Prefill worker needs to connect to decode worker
        decoder_endpoint = (
            runtime.namespace(config.namespace)
            .component("decoder")
            .endpoint("generate")
        )
        decode_worker_client = await decoder_endpoint.client()
        await decode_worker_client.wait_for_instances()
        logger.info("Connected to decode worker for disaggregated mode")

    # Choose handler based on worker type
    if not config.multimodal_decode_worker:
        handler = MultimodalPDWorkerHandler(
            runtime, component, engine_client, config, decode_worker_client
        )
    else:
        handler = MultimodalDecodeWorkerHandler(
            runtime, component, engine_client, config
        )
    handler.add_temp_dir(prometheus_temp_dir)

    await handler.async_init(runtime)

    # Set up KV event publisher for prefix caching if enabled
    # NOTE(review): setup_kv_event_publisher returns a list of publishers, and
    # init/init_prefill store it on handler.kv_publishers (plural) - confirm
    # which attribute the multimodal handlers expect.
    kv_publisher = setup_kv_event_publisher(
        config, component, generate_endpoint, vllm_config
    )
    if kv_publisher:
        handler.kv_publisher = kv_publisher

    metrics_labels = [("model", config.model)]
    try:
        serving = [
            generate_endpoint.serve_endpoint(
                handler.generate,
                metrics_labels=metrics_labels,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=metrics_labels,
            ),
        ]
        await asyncio.gather(*serving)
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()


Alec's avatar
Alec committed
785
786
787
788
def main():
    """Run the worker() coroutine to completion using uvloop."""
    entrypoint = worker()
    uvloop.run(entrypoint)


Alec's avatar
Alec committed
789
# Script entry point: start the worker only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()