main.py 39.6 KB
Newer Older
1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Alec's avatar
Alec committed
2
3
# SPDX-License-Identifier: Apache-2.0

4
import argparse
Alec's avatar
Alec committed
5
6
7
import asyncio
import logging
import os
8
import tempfile
9
import time
10
from typing import Optional
Alec's avatar
Alec committed
11
12

import uvloop
13
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
14
from vllm.config import VllmConfig
Alec's avatar
Alec committed
15
from vllm.distributed.kv_events import ZmqEventPublisher
16
from vllm.entrypoints.cli.serve import run_headless
Alec's avatar
Alec committed
17
18
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
19
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
Alec's avatar
Alec committed
20

21
from dynamo import prometheus_names
22
from dynamo.common.config_dump import dump_config
23
from dynamo.common.storage import get_fs
24
from dynamo.common.utils.endpoint_types import parse_endpoint_types
25
from dynamo.common.utils.graceful_shutdown import install_signal_handlers
26
from dynamo.common.utils.output_modalities import get_output_modalities
27
28
29
30
from dynamo.common.utils.prometheus import (
    LLMBackendMetrics,
    register_engine_metrics_callback,
)
31
from dynamo.common.utils.runtime import create_runtime
Alec's avatar
Alec committed
32
from dynamo.llm import (
33
    KvEventPublisher,
34
    ModelInput,
35
    ModelRuntimeConfig,
Alec's avatar
Alec committed
36
    ModelType,
37
38
    fetch_model,
    register_model,
Alec's avatar
Alec committed
39
)
40
41
42
43
44
45
46
47
48
49
50

# Optional imports for frontend decoding support
try:
    from dynamo.llm import MediaDecoder, MediaFetcher

    MEDIA_DECODER_AVAILABLE = True
except ImportError:
    MediaDecoder = None
    MediaFetcher = None
    MEDIA_DECODER_AVAILABLE = False

51
from dynamo.runtime import DistributedRuntime, Endpoint
Alec's avatar
Alec committed
52
from dynamo.runtime.logging import configure_dynamo_logging
53
from dynamo.vllm.worker_factory import WorkerFactory
Alec's avatar
Alec committed
54

55
from .args import Config, parse_args
56
from .checkpoint_restore import get_checkpoint_config
57
from .constants import DisaggregationMode
Alec's avatar
Alec committed
58
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
59
60
61
62
63
from .health_check import (
    VllmHealthCheckPayload,
    VllmOmniHealthCheckPayload,
    VllmPrefillHealthCheckPayload,
)
64
from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory
Alec's avatar
Alec committed
65

Alec's avatar
Alec committed
66
67
# Initialise Dynamo's logging configuration, then create this module's logger.
configure_dynamo_logging()
logger = logging.getLogger(__name__)
68
# Module-level list handed to install_signal_handlers() in worker(); the init
# functions fill it in place with the endpoints they serve.
shutdown_endpoints: list = []
69
# Sleep level passed to checkpoint_cfg.run_lifecycle() when checkpoint mode is on.
CHECKPOINT_SLEEP_MODE_LEVEL = 1
Alec's avatar
Alec committed
70
71


72
73
74
75
76
77
78
79
80
81
82
83
84
async def _handle_non_leader_node(dp_rank: int) -> None:
    """Park a non-leader data-parallel node until the process is terminated.

    In multi-node deployments, nodes with ``data_parallel_rank >= 1`` run vLLM
    workers but do not serve Dynamo endpoints. This coroutine logs that fact
    and then blocks forever.

    Args:
        dp_rank: Data-parallel rank of this node (used in the log message).
    """
    message = (
        f"Non-leader node detected (data_parallel_rank={dp_rank}). "
        "Skipping endpoint serving."
    )
    logger.info(message)

    # Nothing ever sets this event: the process is expected to be terminated
    # through its signal handlers.
    never_set = asyncio.Event()
    await never_set.wait()


85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
def build_headless_namespace(config: Config) -> argparse.Namespace:
    """Rebuild the CLI namespace that vLLM's run_headless() consumes.

    run_headless() expects the raw CLI namespace. It is reconstructed here from
    the attributes of the already-parsed ``config.engine_args`` so that
    parse_args() does not have to leak transport details.

    Args:
        config: Worker configuration; only ``config.engine_args`` is read.

    Returns:
        A namespace holding a copy of every attribute of ``config.engine_args``,
        plus ``api_server_count`` (0) when the engine args do not define it.
    """
    fields = dict(vars(config.engine_args))
    # run_headless() reads api_server_count; default to 0 (no API server)
    fields.setdefault("api_server_count", 0)
    return argparse.Namespace(**fields)


def run_dynamo_headless(config: Config) -> None:
    """Start vLLM in headless mode for multi-node TP/PP.

    Secondary nodes spawn vLLM workers only: no engine core, no scheduler and
    no Dynamo endpoints. DistributedRuntime is bypassed entirely (no NATS/etcd).

    Args:
        config: Worker configuration whose engine args are converted into the
            namespace handed to vLLM's run_headless().
    """
    run_headless(build_headless_namespace(config))


109
async def worker():
    """Parse the CLI, prepare model and engine, then run the matching worker.

    Steps, in order: parse and dump the config, default the served model name,
    read the checkpoint configuration, fetch the model when it is not a local
    path, hand over to headless mode if requested, pre-create the engine when
    checkpoint mode is active, create the runtime and finally dispatch to the
    initializer selected by the config flags.

    Raises:
        ValueError: If ``--headless`` is combined with checkpoint mode.
    """
    config = parse_args()
    dump_config(config.dump_config_to, config)

    # Name the model. Use either the full path (vllm and sglang do the same),
    # or the HF name (e.g. "Qwen/Qwen3-0.6B"), depending on cmd line params.
    if not config.served_model_name:
        fallback_name = config.model
        config.served_model_name = fallback_name
        config.engine_args.served_model_name = fallback_name

    # Check checkpoint mode and validate env vars EARLY (fail fast if misconfigured)
    exit_now, checkpoint = get_checkpoint_config()
    if exit_now:
        return
    checkpoint_mode = checkpoint is not None

    # Download the model if necessary using modelexpress.
    # We want it on disk before we start vllm to avoid downloading from HuggingFace.
    #
    # We don't set `config.engine_args.model` to the local path fetch_model returns
    # because vllm will send that name to its Ray pipeline-parallel workers, which
    # may not have the local path.
    # vllm will attempt to download the model again, but find it in the HF cache.
    # For non-HF models use a path instead of an HF name, and ensure all workers have
    # that path (ideally via a shared folder).
    model_is_local = os.path.exists(config.model)
    if not model_is_local:
        await fetch_model(config.model)

    # HEADLESS MODE: bypass DistributedRuntime entirely.
    # Workers run vLLM only (no NATS, etcd, or dynamo endpoints).
    if config.headless:
        if checkpoint_mode:
            raise ValueError(
                "--headless is incompatible with checkpoint mode "
                "(DYN_CHECKPOINT_SIGNAL_FILE is set). "
                "Remove --headless or unset DYN_CHECKPOINT_SIGNAL_FILE."
            )
        run_dynamo_headless(config)
        return

    # CHECKPOINT MODE: Load engine BEFORE runtime creation
    # This allows checkpointing GPU state before runtime connections are established
    engine_bundle = None
    if checkpoint_mode:
        logger.info("Checkpoint mode enabled (watcher-driven signals)")

        # Checkpoint mode requires sleep mode — enable before engine init
        config.engine_args.enable_sleep_mode = True

        engine_bundle = setup_vllm_engine(config)
        lifecycle_ok = await checkpoint.run_lifecycle(
            engine_bundle[0], CHECKPOINT_SLEEP_MODE_LEVEL
        )
        # Stop here when run_lifecycle() reports a falsy result.
        if not lifecycle_ok:
            return

    shutdown_event = asyncio.Event()
    runtime, loop = create_runtime(
        discovery_backend=config.discovery_backend,
        request_plane=config.request_plane,
        event_plane=config.event_plane,
        use_kv_events=config.use_kv_events,
    )
    install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event)

    # Route to appropriate initialization based on config flags
    if WorkerFactory.handles(config):
        # Create worker factory with setup functions
        worker_factory = WorkerFactory(
            setup_vllm_engine_fn=setup_vllm_engine,
            setup_kv_event_publisher_fn=setup_kv_event_publisher,
            register_vllm_model_fn=register_vllm_model,
        )
        await worker_factory.create(
            runtime,
            config,
            shutdown_event,
            shutdown_endpoints,
            pre_created_engine=engine_bundle,
        )
        logger.debug("multimodal worker completed")
    elif config.omni:
        await init_omni(runtime, config, shutdown_event)
        logger.debug("init_omni completed")
    elif config.disaggregation_mode == DisaggregationMode.PREFILL:
        await init_prefill(
            runtime, config, shutdown_event, pre_created_engine=engine_bundle
        )
        logger.debug("init_prefill completed")
    else:
        await init(runtime, config, shutdown_event, pre_created_engine=engine_bundle)
        logger.debug("init completed")

    logger.debug("Worker function completed, exiting...")
Alec's avatar
Alec committed
206
207


208
209
210
211
212
213
214
215
216
217
218
219
220
221
def setup_metrics_collection(config: Config, generate_endpoint, logger):
    """Register the engine-metrics callbacks for vLLM and LMCache metrics.

    Nothing is registered unless ``config.engine_args.disable_log_stats`` is
    exactly ``False``.

    In multiprocess mode (PROMETHEUS_MULTIPROC_DIR set), metrics are stored:
      1. In-memory: Metric objects in global REGISTRY
      2. On-disk: Metric values in .db files (PROMETHEUS_MULTIPROC_DIR)

    MultiProcessCollector reads from the .db files, but adding it to REGISTRY
    can fail with "Duplicated timeseries" if PROMETHEUS_MULTIPROC_DIR was set
    before the process started (K8s deployments) because the metrics are
    already in REGISTRY. The function therefore first tries to attach the
    collector to REGISTRY; on ``ValueError`` it falls back to a separate
    registry for the .db files and registers callbacks for both registries.

    Hierarchy labels: the namespace, component, endpoint and model names from
    ``config`` are passed to every vllm/lmcache callback registration.

    Args:
        config: Worker configuration (engine args and hierarchy names).
        generate_endpoint: Endpoint the metrics callbacks are attached to.
        logger: Logger used for the debug messages emitted here.
    """
    if config.engine_args.disable_log_stats is not False:
        return

    def _register_engine_metrics(registry, prefixes):
        # Shared registration for the vllm/lmcache callbacks; only the registry
        # and the metric prefix filter differ between call sites.
        register_engine_metrics_callback(
            endpoint=generate_endpoint,
            registry=registry,
            metric_prefix_filters=prefixes,
            namespace_name=config.namespace,
            component_name=config.component,
            endpoint_name=config.endpoint,
            model_name=config.model,
        )

    # Register the dedicated dynamo_component registry callback
    # IMPORTANT: We do NOT use MultiProcessCollector for DYNAMO_COMPONENT_REGISTRY
    # because our gauges use in-memory values which work fine for single-process
    # and multi-process (each process has its own gauge with dp_rank label).
    # Using MultiProcessCollector would read from .db files which causes stale
    # values to accumulate across test runs.
    register_engine_metrics_callback(
        endpoint=generate_endpoint,
        registry=DYNAMO_COMPONENT_REGISTRY,
    )

    if not os.environ.get("PROMETHEUS_MULTIPROC_DIR"):
        # No multiprocess mode
        _register_engine_metrics(REGISTRY, ["vllm:", "lmcache:"])
        return

    try:
        # MultiProcessCollector reads metrics from .db files in PROMETHEUS_MULTIPROC_DIR
        # Adding it to REGISTRY allows collecting both in-memory and .db file metrics
        multiprocess.MultiProcessCollector(REGISTRY)
        logger.debug("Added MultiProcessCollector to global REGISTRY")
        _register_engine_metrics(REGISTRY, ["vllm:", "lmcache:"])
    except ValueError as exc:
        # Conflict: metrics already in REGISTRY, MultiProcessCollector tries to
        # add same metrics from .db files.
        # Solution: Use separate registry that ONLY reads from .db files
        # (no in-memory conflicts).
        logger.debug(
            f"Could not add MultiProcessCollector to REGISTRY ({exc}), using separate registry"
        )
        multiproc_registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(multiproc_registry)

        # Register both registries to collect all metrics.
        # Global REGISTRY has in-memory metrics (vllm)
        _register_engine_metrics(REGISTRY, ["vllm:"])
        # Multiproc registry has .db file metrics (lmcache, possibly vllm duplicates)
        _register_engine_metrics(multiproc_registry, ["vllm:", "lmcache:"])


Yan Ru Pei's avatar
Yan Ru Pei committed
298
299
def setup_kv_event_publisher(
    config: Config,
    generate_endpoint: Endpoint,
    vllm_config: VllmConfig,
    consolidator_enabled: bool = False,
    consolidator_port: Optional[int] = 5558,
) -> Optional[list[KvEventPublisher]]:
    """
    Set up KV event publishers for prefix caching if enabled.
    Creates one publisher per dp_rank since each dp_rank publishes to a different port.

    Args:
        config: Worker configuration
        generate_endpoint: Endpoint for worker ID
        vllm_config: vLLM configuration
        consolidator_enabled: If True, subscribe to the KV event consolidator's ZMQ endpoint
        consolidator_port: Port where the KV event consolidator publishes (default: 5558).
            Must not be None when consolidator_enabled is True.

    Returns:
        List of KvEventPublisher instances (one per dp_rank) if KV event publishing
        is enabled, None otherwise. None is returned when prefix caching is off,
        for decode workers, when no kv_events_config is present, or when
        enable_kv_cache_events is False.

    Raises:
        ValueError: If consolidator_enabled is True but consolidator_port is None.
    """
    engine_args = config.engine_args
    if not engine_args.enable_prefix_caching:
        return None

    # Skip KV event publishing for decode workers
    if config.disaggregation_mode == DisaggregationMode.DECODE:
        logger.info("Skipping KV event publisher setup for decode worker")
        return None

    kv_events_config = engine_args.kv_events_config
    if kv_events_config is None:
        return None

    # Check if kv_cache_events are explicitly disabled
    if not kv_events_config.enable_kv_cache_events:
        logger.info(
            "KV event publishing skipped: enable_kv_cache_events=False in kv_events_config"
        )
        return None

    # Without a port the endpoint below would become "tcp://127.0.0.1:None";
    # fail loudly instead of handing an unusable address to the publisher.
    if consolidator_enabled and consolidator_port is None:
        raise ValueError(
            "consolidator_port must be set when consolidator_enabled is True"
        )

    # Get data_parallel_size to create publishers for all dp_ranks
    data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
    kv_publishers = []

    for dp_rank in range(data_parallel_size):
        if consolidator_enabled:
            # TODO: Use different port for each dp_rank once KVBM supports DP
            zmq_endpoint = f"tcp://127.0.0.1:{consolidator_port}"
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to consolidator at {zmq_endpoint}"
            )
        else:
            # Each dp_rank publishes to a different port
            zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
                kv_events_config.endpoint,
                data_parallel_rank=dp_rank,
            ).replace("*", "127.0.0.1")
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to vLLM at {zmq_endpoint}"
            )

        kv_publisher = KvEventPublisher(
            endpoint=generate_endpoint,
            kv_block_size=vllm_config.cache_config.block_size,
            zmq_endpoint=zmq_endpoint,
            zmq_topic="",
            enable_local_indexer=config.enable_local_indexer,
            dp_rank=dp_rank,
        )
        kv_publishers.append(kv_publisher)

        logger.info(
            f"Worker reading KV events for dp_rank={dp_rank} from {zmq_endpoint}"
        )

    # An empty list (data_parallel_size of 0) is reported as None, like the
    # disabled cases above.
    return kv_publishers or None
Yan Ru Pei's avatar
Yan Ru Pei committed
372
373


Alec's avatar
Alec committed
374
def setup_vllm_engine(config, stat_logger=None):
    """Prepare the environment and create the vLLM AsyncLLM engine client.

    Side effects: sets PROMETHEUS_MULTIPROC_DIR (when unset), VLLM_NO_USAGE_STATS
    and VLLM_WORKER_MULTIPROC_METHOD in ``os.environ``; sets the LoRA-related
    environment defaults when LoRA is enabled; may set ``worker_cls`` and
    ``ec_transfer_config`` on ``config.engine_args``.

    Args:
        config: Worker configuration; ``config.engine_args`` drives the engine.
        stat_logger: Optional stat-logger factory. When given, it receives the
            component gauges and is passed to the engine as a stat logger.

    Returns:
        A 5-tuple ``(engine_client, vllm_config, default_sampling_params,
        prometheus_temp_dir, component_gauges)``. ``prometheus_temp_dir`` is
        the TemporaryDirectory created here, or None when
        PROMETHEUS_MULTIPROC_DIR was already set.

    Raises:
        ImportError: If a ModelExpress load format is requested but the
            ``modelexpress`` package cannot be imported.
    """
    # vLLM v0.11.0 bug: vllm/v1.metrics/prometheus.py:79 passes TemporaryDirectory object
    # instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR
    # ourselves to avoid this and handle cleanup properly.
    prometheus_temp_dir = None
    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
        prometheus_temp_dir = tempfile.TemporaryDirectory(prefix="vllm_prometheus_")
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_temp_dir.name
        logger.debug(
            f"Created PROMETHEUS_MULTIPROC_DIR at: {os.environ['PROMETHEUS_MULTIPROC_DIR']}"
        )

    setup_multiprocess_prometheus()  # call vLLM's library's function to setup multiprocess prometheus
    logger.debug(
        f"Prometheus multiproc dir set to: {os.environ.get('PROMETHEUS_MULTIPROC_DIR')}"
    )

    # Construct Prometheus gauges AFTER setup_multiprocess_prometheus() so Gauge objects
    # see the correct ValueClass (multiprocess vs in-memory).
    component_gauges = LLMBackendMetrics(
        registry=DYNAMO_COMPONENT_REGISTRY,
        model_name=config.served_model_name or "",
        component_name=config.component or "",
    )

    # If a StatLoggerFactory was provided, give it the gauges so the loggers
    # it creates can publish Prometheus metrics.
    if stat_logger is not None:
        stat_logger.component_gauges = component_gauges

    os.environ["VLLM_NO_USAGE_STATS"] = "1"  # Avoid internal HTTP requests
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    engine_args = config.engine_args

    if engine_args.enable_lora:
        # Only provide defaults; values already set by the operator win.
        if "VLLM_ALLOW_RUNTIME_LORA_UPDATING" not in os.environ:
            os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
        if "VLLM_LORA_MODULES_LOADING_TIMEOUT" not in os.environ:
            os.environ["VLLM_LORA_MODULES_LOADING_TIMEOUT"] = "600"

    if engine_args.load_format == "gms":
        engine_args.worker_cls = "gpu_memory_service.integrations.vllm.worker.GMSWorker"

    if engine_args.load_format in ("mx-source", "mx-target"):
        try:
            from modelexpress import register_modelexpress_loaders

            # Ensure the ModelExpress server URL env var is set for the model loader
            if config.model_express_url:
                os.environ["MODEL_EXPRESS_URL"] = config.model_express_url
            register_modelexpress_loaders()
            # Use wrapper worker to ensure loaders are registered in spawned worker processes
            engine_args.worker_cls = "modelexpress.vllm_worker.ModelExpressWorker"
        except ImportError as e:
            raise ImportError(
                f"ModelExpress package required for --load-format={engine_args.load_format}. "
                "Install with: pip install modelexpress"
            ) from e

    # Load default sampling params from `generation_config.json`
    default_sampling_params = (
        engine_args.create_model_config().get_diff_sampling_param()
    )

    # Configure ec_both mode with DynamoMultimodalEmbeddingCacheConnector.
    # Must happen BEFORE engine setup so vLLM sees ec_transfer_config.
    if (
        not config.route_to_encoder
        and config.multimodal_embedding_cache_capacity_gb > 0
    ):
        from vllm.config import ECTransferConfig

        logger.info(
            "Configuring ec_both mode with DynamoMultimodalEmbeddingCacheConnector "
            "(capacity=%.2f GB)",
            config.multimodal_embedding_cache_capacity_gb,
        )
        instance_id = 0
        engine_id = f"{config.namespace}.{config.component}.backend.{instance_id}"
        engine_args.ec_transfer_config = ECTransferConfig(
            engine_id=engine_id,
            ec_role="ec_both",
            ec_connector="DynamoMultimodalEmbeddingCacheConnector",
            ec_connector_module_path="dynamo.vllm.multimodal_utils.multimodal_embedding_cache_connector",
            ec_connector_extra_config={
                "multimodal_embedding_cache_capacity_gb": config.multimodal_embedding_cache_capacity_gb,
            },
        )
        logger.info("Configured ec_both with engine_id=%s", engine_id)

    # Taken from build_async_engine_client_from_engine_args()
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    # Set up consolidator endpoints if KVBM is enabled
    consolidator_endpoints = None
    if config.has_connector("kvbm"):
        try:
            from kvbm.vllm_integration.consolidator_config import (
                get_consolidator_endpoints,
            )

            consolidator_endpoints = get_consolidator_endpoints(vllm_config)
        except Exception as e:
            # Deliberate best-effort: the worker keeps running without
            # consolidation when the endpoints cannot be resolved.
            logger.warning(
                f"KVBM connector is enabled but failed to get consolidator endpoints: {e}. "
                "Continuing without KV event consolidation. "
                "Ensure 'kvbm' package is installed if this feature is needed."
            )
    # Stored on the config object so later setup code can read it back.
    vllm_config.consolidator_endpoints = consolidator_endpoints

    factory = []
    if stat_logger:
        factory.append(stat_logger)

    # Time engine initialization with a monotonic clock: wall-clock time can
    # jump (NTP, manual changes) and would skew the recorded duration.
    start_time = time.perf_counter()
    engine_client = AsyncLLM.from_vllm_config(
        vllm_config=vllm_config,
        usage_context=usage_context,
        stat_loggers=factory,
        enable_log_requests=engine_args.enable_log_requests,
        disable_log_stats=engine_args.disable_log_stats,
    )
    load_time = time.perf_counter() - start_time

    # Record model load time
    component_gauges.set_model_load_time(load_time)

    logger.info(f"VllmWorker for {config.served_model_name} has been initialized")

    return (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
        component_gauges,
    )
Alec's avatar
Alec committed
513
514


515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
async def register_vllm_model(
    model_input: ModelInput,
    model_type: ModelType,
    generate_endpoint,
    config: Config,
    engine_client: AsyncLLM,
    vllm_config,
):
    """
    Helper function to register a vLLM model with runtime configuration.

    Args:
        model_input: Input type for the model (e.g., ModelInput.Tokens)
        model_type: Type of model (e.g., ModelType.Chat, ModelType.Prefill)
        generate_endpoint: Endpoint to register
        config: Configuration object
        engine_client: vLLM engine client
        vllm_config: vLLM configuration

    Raises:
        RuntimeError: If frontend decoding is requested but MediaDecoder /
            MediaFetcher could not be imported from dynamo.llm.
    """
    runtime_config = ModelRuntimeConfig()

    # Get runtime configuration from vLLM engine.
    # Use the module logger (not the root logger) like the rest of this file.
    logger.info(
        f"Getting engine runtime configuration metadata from vLLM engine for {model_type}..."
    )
    runtime_values = get_engine_cache_info(engine_client)
    runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
    runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
    runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
    # Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer
    runtime_config.enable_local_indexer = (
        config.enable_local_indexer
        and config.disaggregation_mode != DisaggregationMode.DECODE
    )

    # Add tool/reasoning parsers for decode models
    if model_type != ModelType.Prefill:
        runtime_config.tool_call_parser = config.dyn_tool_call_parser
        runtime_config.reasoning_parser = config.dyn_reasoning_parser

    # Get data_parallel_size from vllm_config (defaults to 1)
    data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
    runtime_config.data_parallel_size = data_parallel_size

    # Configure media decoder for frontend image decoding when enabled
    # This enables frontend to decode images and transfer via NIXL RDMA
    media_decoder = None
    media_fetcher = None
    if config.frontend_decoding:
        if not MEDIA_DECODER_AVAILABLE:
            raise RuntimeError(
                "--frontend-decoding requires MediaDecoder support. "
                "Ensure dynamo.llm module includes MediaDecoder and MediaFetcher."
            )
        media_decoder = MediaDecoder()
        media_decoder.enable_image({"limits": {"max_alloc": 128 * 1024 * 1024}})
        # media_decoder.enable_video({})

        media_fetcher = MediaFetcher()
        media_fetcher.timeout_ms(30000)
        media_fetcher.allow_direct_port(True)

    await register_model(
        model_input,
        model_type,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=runtime_values["block_size"],
        runtime_config=runtime_config,
        custom_template_path=config.custom_jinja_template,
        media_decoder=media_decoder,
        media_fetcher=media_fetcher,
    )


591
async def init_prefill(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    pre_created_engine=None,
):
    """
    Instantiate and serve a prefill worker.

    Creates (or reuses) the vLLM engine, builds the PrefillWorkerHandler, wires
    up KV event publishing and metrics, registers the sleep/wake_up engine
    routes, registers the model as ModelType.Prefill and then serves the
    generate and clear_kv_blocks endpoints until they complete.

    Args:
        runtime: Distributed runtime used to create endpoints and engine routes.
        config: Worker configuration.
        shutdown_event: Event handed to the handler for shutdown coordination.
        pre_created_engine: Optional tuple as returned by setup_vllm_engine()
            (checkpoint mode). When None, a new engine is created here.

    Note:
        When ``config.engine_args.data_parallel_rank`` is truthy (non-leader
        node) no endpoints are served; the coroutine waits in
        _handle_non_leader_node() instead.
    """
    generate_endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.{config.endpoint}"
    )
    clear_endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.clear_kv_blocks"
    )

    # Use pre-created engine if provided (checkpoint mode), otherwise create new
    engine_bundle = (
        pre_created_engine
        if pre_created_engine is not None
        else setup_vllm_engine(config)
    )
    (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
        _component_gauges,
    ) = engine_bundle

    handler = PrefillWorkerHandler(
        runtime,
        engine_client,
        default_sampling_params,
        getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
        shutdown_event=shutdown_event,
        enable_frontend_decoding=config.frontend_decoding,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
    consolidator_enabled = False
    consolidator_port = None

    if (
        hasattr(vllm_config, "consolidator_endpoints")
        and vllm_config.consolidator_endpoints
    ):
        # Extract connect endpoint (third element) for clients to subscribe
        # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint)
        consolidator_output_endpoint = vllm_config.consolidator_endpoints[2]
        consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
        consolidator_enabled = True

    # Set up KV event publishers for prefix caching if enabled (one per dp_rank)
    # If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
    kv_publishers = setup_kv_event_publisher(
        config,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    setup_metrics_collection(config, generate_endpoint, logger)

    # Register sleep/wake_up engine routes
    runtime.register_engine_route("sleep", handler.sleep)
    runtime.register_engine_route("wake_up", handler.wake_up)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # Handle non-leader nodes - don't serve endpoints
    if config.engine_args.data_parallel_rank:
        await _handle_non_leader_node(config.engine_args.data_parallel_rank)
        return
    shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]

    # Register prefill model with ModelType.Prefill
    model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
    await register_vllm_model(
        model_input,
        ModelType.Prefill,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
    )

    health_check_payload = VllmPrefillHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    ).to_dict()

    # In practice config.served_model_name is always set, but mypy needs the "or" here.
    # Both endpoints share the same label value so their metrics stay consistent.
    model_label = config.served_model_name or config.model
    model_metrics_labels = [
        (prometheus_names.labels.MODEL, model_label),
        (prometheus_names.labels.MODEL_NAME, model_label),
    ]

    try:
        logger.debug("Starting serve_endpoint for prefill worker")
        await asyncio.gather(
            # for prefill, we want to shutdown the engine after all prefill requests are finished because
            #     (temp reason): we don't support re-routing prefill requests
            #     (long-term reason): prefill engine should pull from a global queue so there is
            #                         only a few in-flight requests that can be quickly finished
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=list(model_metrics_labels),
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=list(model_metrics_labels),
            ),
        )
        logger.debug("serve_endpoint completed for prefill worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up prefill worker")
        handler.cleanup()


733
async def init(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    pre_created_engine=None,
):
    """
    Instantiate and serve the decode worker.

    Creates the generate / clear_kv_blocks endpoints (plus the LoRA management
    endpoints when ``config.engine_args.enable_lora`` is set), builds or adopts
    the vLLM engine, wires up the handler, KV event publishing and metrics,
    registers the model and finally serves every endpoint.

    When ``config.engine_args.data_parallel_rank`` is truthy the function
    awaits ``_handle_non_leader_node()`` and returns without registering the
    model or serving any endpoint.

    Args:
        runtime: Distributed runtime used to create endpoints and to register
            the ``sleep`` / ``wake_up`` engine routes.
        config: Worker configuration.
        shutdown_event: Event handed to the ``DecodeWorkerHandler``.
        pre_created_engine: Optional 5-tuple ``(engine_client, vllm_config,
            default_sampling_params, prometheus_temp_dir, component_gauges)``
            to reuse (checkpoint mode). When ``None`` the engine is created
            with ``setup_vllm_engine()``.

    Raises:
        Exception: Any error raised while serving is logged and re-raised;
            ``handler.cleanup()`` runs in either case.
    """

    # All endpoints of this worker share the same "<namespace>.<component>" prefix.
    endpoint_prefix = f"{config.namespace}.{config.component}"

    generate_endpoint = runtime.endpoint(f"{endpoint_prefix}.{config.endpoint}")
    clear_endpoint = runtime.endpoint(f"{endpoint_prefix}.clear_kv_blocks")

    # Slice assignment mutates the module-level list in place instead of rebinding it.
    shutdown_endpoints[:] = [
        generate_endpoint,
        clear_endpoint,
    ]

    lora_enabled = config.engine_args.enable_lora
    if lora_enabled:
        load_lora_endpoint = runtime.endpoint(f"{endpoint_prefix}.load_lora")
        unload_lora_endpoint = runtime.endpoint(f"{endpoint_prefix}.unload_lora")
        list_loras_endpoint = runtime.endpoint(f"{endpoint_prefix}.list_loras")

        shutdown_endpoints.extend(
            [
                load_lora_endpoint,
                unload_lora_endpoint,
                list_loras_endpoint,
            ]
        )

    # data_parallel_rank may be unset; the stat logger factory gets rank 0 then.
    dp_rank = config.engine_args.data_parallel_rank or 0

    # Use pre-created engine if provided (checkpoint mode), otherwise create new
    if pre_created_engine is not None:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            component_gauges,
        ) = pre_created_engine
        # Factory is created after unpack so component_gauges is available
        factory = StatLoggerFactory(
            endpoint=generate_endpoint,
            component_gauges=component_gauges,
            dp_rank=dp_rank,
        )
    else:
        # Factory is created without component_gauges; setup_vllm_engine() will
        # create the gauges after setup_multiprocess_prometheus() and set them
        # on the factory before vLLM calls create_stat_logger().
        factory = StatLoggerFactory(
            endpoint=generate_endpoint,
            dp_rank=dp_rank,
        )
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            _component_gauges,  # not needed here; see the comment above
        ) = setup_vllm_engine(config, factory)

    # TODO Hack to get data, move this to registering in TBD
    factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks)
    factory.init_publish()

    handler = DecodeWorkerHandler(
        runtime,
        engine_client,
        default_sampling_params,
        # Falls back to None when vllm_config has no model_config / max_model_len.
        getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
        shutdown_event=shutdown_event,
        enable_frontend_decoding=config.frontend_decoding,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
    consolidator_enabled = False
    consolidator_port = None

    consolidator_endpoints = getattr(vllm_config, "consolidator_endpoints", None)
    if consolidator_endpoints:
        # Extract connect endpoint (third element) for clients to subscribe
        # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint)
        consolidator_output_endpoint = consolidator_endpoints[2]
        consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
        consolidator_enabled = True

    # Set up KV event publisher for prefix caching if enabled
    # If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
    kv_publishers = setup_kv_event_publisher(
        config,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    setup_metrics_collection(config, generate_endpoint, logger)

    # Register sleep/wake_up engine routes
    runtime.register_engine_route("sleep", handler.sleep)
    runtime.register_engine_route("wake_up", handler.wake_up)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # Handle non-leader nodes - don't serve endpoints
    if config.engine_args.data_parallel_rank:
        await _handle_non_leader_node(config.engine_args.data_parallel_rank)
        return

    # Parse endpoint types from --endpoint-types flag
    model_type = parse_endpoint_types(config.endpoint_types)
    logger.info(f"Registering model with endpoint types: {config.endpoint_types}")

    # Text input when the vLLM tokenizer is used, token input otherwise.
    model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens

    # Warn if custom template provided but chat endpoint not enabled
    if config.custom_jinja_template and "chat" not in config.endpoint_types:
        logger.warning(
            "Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
            "The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
        )

    await register_vllm_model(
        model_input,
        model_type,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
    )

    health_check_payload = VllmHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    ).to_dict()

    try:
        logger.debug("Starting serve_endpoint for decode worker")

        # Fall back to the model path when no served name is configured.
        model_label = config.served_model_name or config.model
        model_metrics_labels = [
            (prometheus_names.labels.MODEL, model_label),
            (prometheus_names.labels.MODEL_NAME, model_label),
        ]

        serve_tasks = [
            # for decode, we want to transfer the in-flight requests to other decode engines,
            # because waiting them to finish can take a long time for long OSLs
            # NOTE(review): graceful_shutdown=True is also what the prefill worker
            # passes, next to a comment about waiting for requests to finish;
            # confirm that True is really the intended value for decode.
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=model_metrics_labels,
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=model_metrics_labels,
            ),
        ]

        if lora_enabled:
            serve_tasks.extend(
                [
                    load_lora_endpoint.serve_endpoint(
                        handler.load_lora,
                        metrics_labels=model_metrics_labels,
                    ),
                    unload_lora_endpoint.serve_endpoint(
                        handler.unload_lora,
                        metrics_labels=model_metrics_labels,
                    ),
                    list_loras_endpoint.serve_endpoint(
                        handler.list_loras,
                        metrics_labels=model_metrics_labels,
                    ),
                ]
            )

        await asyncio.gather(*serve_tasks)
        logger.debug("serve_endpoint completed for decode worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up decode worker")
        # Cleanup background tasks
        handler.cleanup()


946
947
948
949
950
951
952
def get_engine_cache_info(engine: "AsyncLLM"):
    """Retrieve cache configuration information from [`AsyncLLM`] engine.

    The values are read directly from ``engine.vllm_config`` rather than via
    ``collective_rpc``.

    Args:
        engine: Engine whose ``vllm_config`` exposes ``cache_config`` and
            ``scheduler_config``.

    Returns:
        A dict with the keys ``num_gpu_blocks``, ``block_size``,
        ``max_num_seqs`` and ``max_num_batched_tokens`` (in that order).

    Raises:
        Exception: Any error raised while reading the configuration is logged
            and re-raised unchanged.
    """

    try:
        # Get values directly from vllm_config instead of collective_rpc
        cache_config = engine.vllm_config.cache_config
        cache_values = {
            "num_gpu_blocks": cache_config.num_gpu_blocks,
            "block_size": cache_config.block_size,
        }

        scheduler_config = engine.vllm_config.scheduler_config
        scheduler_values = {
            "max_num_seqs": scheduler_config.max_num_seqs,
            "max_num_batched_tokens": scheduler_config.max_num_batched_tokens,
        }

        logging.info(f"Cache config values: {cache_values}")
        logging.info(f"Scheduler config values: {scheduler_values}")

        # Flat view of both groups; cache keys first, then scheduler keys.
        return {**cache_values, **scheduler_values}
    except Exception as e:
        logging.error(f"Failed to get configuration values from vLLM config: {e}")
        raise


974
975
976
async def init_omni(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """Initialize Omni worker for multi-stage pipeline generation using vLLM-Omni.

    Supports text-to-text, text-to-image, and text-to-video generation
    through a single unified OmniHandler.

    When ``config.engine_args.data_parallel_rank`` is truthy the function
    awaits ``_handle_non_leader_node()`` and returns without registering the
    model or serving the endpoint.

    Args:
        runtime: Distributed runtime used to create the generate endpoint.
        config: Worker configuration.
        shutdown_event: Event handed to the ``OmniHandler``.

    Raises:
        Exception: Any error raised while serving is logged and re-raised;
            ``handler.cleanup()`` runs in either case.
    """
    # Function-scope import — presumably so vLLM-Omni is only needed when this
    # worker type is actually started (TODO confirm).
    from dynamo.vllm.omni import OmniHandler

    generate_endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.{config.endpoint}"
    )

    # Slice assignment mutates the module-level list in place instead of rebinding it.
    shutdown_endpoints[:] = [generate_endpoint]

    # Initialize media filesystem for storing generated images/videos
    media_fs = None
    if config.media_output_fs_url:
        media_fs = get_fs(config.media_output_fs_url)

    # Initialize unified OmniHandler
    handler = OmniHandler(
        runtime=runtime,
        config=config,
        default_sampling_params={},
        shutdown_event=shutdown_event,
        media_output_fs=media_fs,
        media_output_http_url=config.media_output_http_url,
    )

    logger.info(f"Omni worker initialized for model: {config.model}")

    # Set up metrics collection for vLLM and LMCache metrics
    setup_metrics_collection(config, generate_endpoint, logger)

    # Handle non-leader nodes - don't serve endpoints
    if config.engine_args.data_parallel_rank:
        await _handle_non_leader_node(config.engine_args.data_parallel_rank)
        return

    # TODO: extend for multi-stage pipelines
    model_type = get_output_modalities(config.output_modalities, config.model)
    if model_type is None:
        # Default to Images
        model_type = ModelType.Images

    await register_model(
        ModelInput.Text,
        model_type,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info("Starting to serve Omni worker endpoint...")

    health_check = await VllmOmniHealthCheckPayload.create(handler.engine_client)
    health_check_payload = health_check.to_dict()

    try:
        # Fall back to the model path when no served name is configured.
        model_label = config.served_model_name or config.model
        model_metrics_labels = [
            (prometheus_names.labels.MODEL, model_label),
            (prometheus_names.labels.MODEL_NAME, model_label),
        ]

        await generate_endpoint.serve_endpoint(
            handler.generate,
            graceful_shutdown=True,
            metrics_labels=model_metrics_labels,
            health_check_payload=health_check_payload,
        )
    except Exception as e:
        logger.error(f"Failed to serve Omni endpoint: {e}")
        raise
    finally:
        logger.debug("Cleaning up Omni worker")
        handler.cleanup()


Alec's avatar
Alec committed
1059
1060
1061
1062
def main():
    """Synchronous entry point: run the ``worker()`` coroutine via ``uvloop.run``."""
    worker_coro = worker()
    uvloop.run(worker_coro)


Alec's avatar
Alec committed
1063
# Start the worker only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()