# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import logging
import os
7
import tempfile
8
import time
9
from typing import Optional
Alec's avatar
Alec committed
10
11

import uvloop
12
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
Alec's avatar
Alec committed
13
14
15
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
16
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
Alec's avatar
Alec committed
17

18
from dynamo import prometheus_names
19
from dynamo.common.config_dump import dump_config
20
from dynamo.common.storage import get_fs
21
from dynamo.common.utils.endpoint_types import parse_endpoint_types
22
from dynamo.common.utils.output_modalities import get_output_modalities
23
24
25
26
from dynamo.common.utils.prometheus import (
    LLMBackendMetrics,
    register_engine_metrics_callback,
)
27
from dynamo.common.utils.runtime import create_runtime
Alec's avatar
Alec committed
28
from dynamo.llm import (
29
    KvEventPublisher,
30
    ModelInput,
31
    ModelRuntimeConfig,
Alec's avatar
Alec committed
32
    ModelType,
33
34
    fetch_model,
    register_model,
Alec's avatar
Alec committed
35
)
36
37
38
39
40
41
42
43
44
45
46

# Optional imports for frontend decoding support
try:
    from dynamo.llm import MediaDecoder, MediaFetcher

    MEDIA_DECODER_AVAILABLE = True
except ImportError:
    MediaDecoder = None
    MediaFetcher = None
    MEDIA_DECODER_AVAILABLE = False

47
from dynamo.runtime import DistributedRuntime
Alec's avatar
Alec committed
48
from dynamo.runtime.logging import configure_dynamo_logging
49
from dynamo.vllm.worker_factory import WorkerFactory
Alec's avatar
Alec committed
50

51
from .args import Config, parse_args
52
from .checkpoint_restore import get_checkpoint_config
Alec's avatar
Alec committed
53
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
54
55
56
57
58
from .health_check import (
    VllmHealthCheckPayload,
    VllmOmniHealthCheckPayload,
    VllmPrefillHealthCheckPayload,
)
59
from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory
Alec's avatar
Alec committed
60

Alec's avatar
Alec committed
61
62
# Configure Dynamo's logging once at import time, before the module logger
# below emits anything.
configure_dynamo_logging()
logger: logging.Logger = logging.getLogger(__name__)

# Level argument handed to checkpoint_cfg.run_lifecycle() in worker() when
# checkpoint mode (which forces enable_sleep_mode=True) is active.
CHECKPOINT_SLEEP_MODE_LEVEL: int = 1
Alec's avatar
Alec committed
64
65


66
67
68
69
70
71
72
73
74
75
76
77
78
async def _handle_non_leader_node(dp_rank: int) -> None:
    """
    Park a non-leader data-parallel rank without serving any endpoints.

    In multi-node deployments, ranks with data_parallel_rank >= 1 still run
    vLLM workers but do not expose Dynamo endpoints. This coroutine logs that
    fact and then blocks forever on an event nobody sets; the process is
    expected to be terminated via signal handlers.

    Args:
        dp_rank: The data-parallel rank of this (non-leader) node.
    """
    message = (
        f"Non-leader node detected (data_parallel_rank={dp_rank}). "
        "Skipping endpoint serving."
    )
    logger.info(message)

    # An Event that is never set: awaiting it suspends this coroutine until
    # it is cancelled or the process exits.
    never_set = asyncio.Event()
    await never_set.wait()


79
async def graceful_shutdown(runtime, shutdown_event):
    """
    Shut down the Dynamo DistributedRuntime.

    ``shutdown_event`` is set first, so code holding the same event (it is
    passed to create_runtime() and the worker handlers in this module) can
    begin winding down. Then ``runtime.shutdown()`` is called.

    The endpoints are immediately invalidated, so no new requests are accepted.
    For endpoints served with graceful_shutdown=True, the serving function waits
    until all in-flight requests are finished. For endpoints served with
    graceful_shutdown=False, the serving function returns immediately.

    Args:
        runtime: The DistributedRuntime to shut down.
        shutdown_event: asyncio.Event signalled before the runtime shutdown.
    """
    # Use the module logger (as the rest of this file does) rather than the
    # root logger, so these messages carry this module's logger name.
    logger.info("Received shutdown signal, shutting down DistributedRuntime")
    shutdown_event.set()
    runtime.shutdown()
    logger.info("DistributedRuntime shutdown complete")
90
91


92
async def _run_worker_role(runtime, config, shutdown_event, pre_created_engine):
    """
    Await the initializer that matches this worker's role.

    Precedence, checked in order:
      1. configs claimed by WorkerFactory.handles() (multimodal workers),
      2. ``config.omni``,
      3. ``config.is_prefill_worker``,
      4. otherwise the default ``init()``.

    ``pre_created_engine`` (set only in checkpoint mode) is forwarded to every
    path except the omni path, which does not accept it.
    """
    if WorkerFactory.handles(config):
        # Factory receives the setup helpers defined in this module.
        role_factory = WorkerFactory(
            setup_vllm_engine_fn=setup_vllm_engine,
            setup_kv_event_publisher_fn=setup_kv_event_publisher,
            register_vllm_model_fn=register_vllm_model,
        )
        await role_factory.create(
            runtime, config, shutdown_event, pre_created_engine=pre_created_engine
        )
        logger.debug("multimodal worker completed")
        return

    if config.omni:
        await init_omni(runtime, config, shutdown_event)
        logger.debug("init_omni completed")
        return

    if config.is_prefill_worker:
        await init_prefill(
            runtime, config, shutdown_event, pre_created_engine=pre_created_engine
        )
        logger.debug("init_prefill completed")
        return

    await init(runtime, config, shutdown_event, pre_created_engine=pre_created_engine)
    logger.debug("init completed")


async def worker():
    """
    Top-level worker coroutine.

    Parses the CLI config, dumps it, and resolves the served model name.
    It then validates checkpoint settings and makes sure the model is
    available locally. In checkpoint mode it builds the engine and runs the
    checkpoint lifecycle *before* any runtime connection exists. Finally it
    creates the DistributedRuntime and dispatches to the role-specific
    initializer.
    """
    config = parse_args()
    dump_config(config.dump_config_to, config)

    # Name the model. Use either the full path (vllm and sglang do the same),
    # or the HF name (e.g. "Qwen/Qwen3-0.6B"), depending on cmd line params.
    if not config.served_model_name:
        fallback_name = config.model
        config.served_model_name = fallback_name
        config.engine_args.served_model_name = fallback_name

    # Check checkpoint mode and validate env vars EARLY (fail fast if misconfigured).
    should_exit_early, checkpoint_cfg = get_checkpoint_config()
    if should_exit_early:
        return

    # Download the model if necessary using modelexpress, so it is on disk
    # before vllm starts and vllm does not download it from HuggingFace.
    #
    # config.engine_args.model is deliberately NOT pointed at the local path:
    # vllm sends that name to its Ray pipeline-parallel workers, which may not
    # have the local path. vllm will try to download again but will hit the HF
    # cache. For non-HF models, pass a path instead of an HF name and make sure
    # every worker can see it (ideally via a shared folder).
    if not os.path.exists(config.model):
        await fetch_model(config.model)

    # CHECKPOINT MODE: load the engine BEFORE runtime creation so GPU state can
    # be checkpointed before any runtime connections are established.
    engine_bundle = None
    if checkpoint_cfg is not None:
        logger.info("Checkpoint mode enabled (watcher-driven signals)")

        # Checkpoint mode requires sleep mode, so enable it before engine init.
        config.engine_args.enable_sleep_mode = True

        engine_bundle = setup_vllm_engine(config)
        engine_client = engine_bundle[0]

        lifecycle_ok = await checkpoint_cfg.run_lifecycle(
            engine_client, CHECKPOINT_SLEEP_MODE_LEVEL
        )
        if not lifecycle_ok:
            return

    shutdown_event = asyncio.Event()
    runtime, _ = create_runtime(
        discovery_backend=config.discovery_backend,
        request_plane=config.request_plane,
        event_plane=config.event_plane,
        use_kv_events=config.use_kv_events,
        shutdown_event=shutdown_event,
    )

    await _run_worker_role(runtime, config, shutdown_event, engine_bundle)

    logger.debug("Worker function completed, exiting...")
Alec's avatar
Alec committed
172
173


174
175
176
177
178
179
180
181
182
183
184
185
186
187
def setup_metrics_collection(config: Config, generate_endpoint, logger):
    """Set up metrics collection for vLLM and LMCache metrics.

    Nothing is registered unless ``config.engine_args.disable_log_stats`` is
    exactly ``False``.

    The dedicated DYNAMO_COMPONENT_REGISTRY always gets its own callback,
    without a MultiProcessCollector. Its gauges use in-memory values, and
    reading .db files would let stale values accumulate across test runs.

    In multiprocess mode (PROMETHEUS_MULTIPROC_DIR set), metrics are stored:
      1. In-memory: Metric objects in global REGISTRY
      2. On-disk: Metric values in .db files (PROMETHEUS_MULTIPROC_DIR)

    MultiProcessCollector reads from .db files, but adding it to REGISTRY can
    fail with "Duplicated timeseries" if PROMETHEUS_MULTIPROC_DIR was set
    before the process started (K8s deployments), because the metrics are
    already in REGISTRY.

    Solution: try adding MultiProcessCollector to REGISTRY. If that fails, use
    a separate registry for multiprocess collection and register callbacks on
    both registries, so all metrics (vllm, lmcache, dynamo_component) are
    collected.

    Auto-label injection:
        Every engine-metrics callback (vllm:/lmcache: prefixes) receives
        config.namespace / config.component / config.endpoint / config.model.
        This aligns the Python metrics' hierarchy labels with the Rust
        auto-labels.

    Args:
        config: Worker configuration.
        generate_endpoint: Endpoint the metric callbacks are attached to.
        logger: Logger used for debug messages (shadows the module logger).
    """
    if config.engine_args.disable_log_stats is not False:
        return

    # Register the dedicated dynamo_component registry callback.
    # IMPORTANT: we do NOT use MultiProcessCollector for DYNAMO_COMPONENT_REGISTRY:
    # our gauges use in-memory values, which work for single-process and
    # multi-process alike (each process has its own gauge with a dp_rank label).
    # MultiProcessCollector would read .db files and accumulate stale values
    # across test runs.
    register_engine_metrics_callback(
        endpoint=generate_endpoint,
        registry=DYNAMO_COMPONENT_REGISTRY,
    )

    def _register(registry, prefixes):
        """Register an engine-metrics callback on *registry* with hierarchy labels."""
        register_engine_metrics_callback(
            endpoint=generate_endpoint,
            registry=registry,
            metric_prefix_filters=prefixes,
            namespace_name=config.namespace,
            component_name=config.component,
            endpoint_name=config.endpoint,
            model_name=config.model,
        )

    if not os.environ.get("PROMETHEUS_MULTIPROC_DIR"):
        # No multiprocess mode: everything lives in the global REGISTRY.
        _register(REGISTRY, ["vllm:", "lmcache:"])
        return

    try:
        # MultiProcessCollector reads metrics from .db files in PROMETHEUS_MULTIPROC_DIR.
        # Adding it to REGISTRY lets us collect both in-memory and .db file metrics.
        # Only this call is guarded: its ValueError signals the duplicate-timeseries
        # conflict handled below.
        multiprocess.MultiProcessCollector(REGISTRY)
    except ValueError as e:
        # Conflict: metrics are already in REGISTRY, and MultiProcessCollector
        # tries to add the same metrics from .db files.
        # Solution: use a separate registry that ONLY reads .db files (no
        # in-memory conflicts).
        logger.debug(
            f"Could not add MultiProcessCollector to REGISTRY ({e}), using separate registry"
        )
        multiproc_registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(multiproc_registry)

        # Register both registries so all metrics are collected.
        # Global REGISTRY has the in-memory metrics (vllm).
        _register(REGISTRY, ["vllm:"])
        # Multiproc registry has the .db file metrics (lmcache, possibly vllm duplicates).
        _register(multiproc_registry, ["vllm:", "lmcache:"])
    else:
        logger.debug("Added MultiProcessCollector to global REGISTRY")
        _register(REGISTRY, ["vllm:", "lmcache:"])


Yan Ru Pei's avatar
Yan Ru Pei committed
264
265
266
267
268
def setup_kv_event_publisher(
    config: Config,
    component,
    generate_endpoint,
    vllm_config,
    consolidator_enabled: bool = False,
    consolidator_port: Optional[int] = 5558,
) -> Optional[list[KvEventPublisher]]:
    """
    Set up KV event publishers for prefix caching if enabled.

    Creates one publisher per dp_rank, because each dp_rank publishes to a
    different port. Publishers are created only when all of these hold:
    prefix caching is enabled, this is not a decode worker,
    ``kv_events_config`` is set, and its ``enable_kv_cache_events`` is true.

    Args:
        config: Worker configuration
        component: Component for runtime integration
        generate_endpoint: Endpoint for worker ID (not used by this function;
            kept for interface compatibility with existing callers)
        vllm_config: vLLM configuration (supplies data_parallel_size and the
            KV cache block size)
        consolidator_enabled: If True, subscribe to the KV event consolidator's
            ZMQ endpoint instead of vLLM's own endpoint
        consolidator_port: Port where the KV event consolidator publishes
            (default: 5558)

    Returns:
        List of KvEventPublisher instances (one per dp_rank) if KV event
        publishing is enabled, None otherwise.

    Raises:
        ValueError: If ``consolidator_enabled`` is True but ``consolidator_port``
            is None. That combination would otherwise yield the invalid
            endpoint ``tcp://127.0.0.1:None``.
    """
    if not config.engine_args.enable_prefix_caching:
        return None

    # Skip KV event publishing for decode workers
    if config.is_decode_worker:
        logger.info("Skipping KV event publisher setup for decode worker")
        return None

    kv_events_config = config.engine_args.kv_events_config
    if kv_events_config is None:
        return None

    # Check if kv_cache_events are explicitly disabled
    if not kv_events_config.enable_kv_cache_events:
        logger.info(
            "KV event publishing skipped: enable_kv_cache_events=False in kv_events_config"
        )
        return None

    if consolidator_enabled and consolidator_port is None:
        raise ValueError(
            "consolidator_enabled=True requires a consolidator_port, got None"
        )

    # Get data_parallel_size to create publishers for all dp_ranks
    data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
    block_size = vllm_config.cache_config.block_size
    kv_publishers = []

    for dp_rank in range(data_parallel_size):
        if consolidator_enabled:
            # TODO: Use different port for each dp_rank once KVBM supports DP
            zmq_endpoint = f"tcp://127.0.0.1:{consolidator_port}"
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to consolidator at {zmq_endpoint}"
            )
        else:
            # Each dp_rank publishes to a different port. The bind wildcard "*"
            # is replaced with loopback so we can connect to it as a subscriber.
            zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
                kv_events_config.endpoint,
                data_parallel_rank=dp_rank,
            ).replace("*", "127.0.0.1")
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to vLLM at {zmq_endpoint}"
            )

        kv_publishers.append(
            KvEventPublisher(
                component=component,
                kv_block_size=block_size,
                zmq_endpoint=zmq_endpoint,
                zmq_topic="",
                enable_local_indexer=config.enable_local_indexer,
                dp_rank=dp_rank,
            )
        )

        logger.info(
            f"Worker reading KV events for dp_rank={dp_rank} from {zmq_endpoint}"
        )

    return kv_publishers if kv_publishers else None
Yan Ru Pei's avatar
Yan Ru Pei committed
340
341


Alec's avatar
Alec committed
342
def setup_vllm_engine(config, stat_logger=None):
    """Create the vLLM AsyncLLM engine together with its Prometheus plumbing.

    Steps:
      1. Ensure PROMETHEUS_MULTIPROC_DIR is set (creating a temp dir if not),
         then call vLLM's setup_multiprocess_prometheus().
      2. Create the LLMBackendMetrics gauges in DYNAMO_COMPONENT_REGISTRY and
         hand them to ``stat_logger`` when one is given.
      3. Set vLLM environment variables (usage stats off, spawn workers, LoRA
         defaults) and select the worker class for the ``gms`` and ModelExpress
         (``mx-source``/``mx-target``) load formats.
      4. Build the vLLM config, resolve KVBM consolidator endpoints (best
         effort), create the engine, and record its load time on the gauges.

    Args:
        config: Worker configuration. ``config.engine_args.worker_cls`` is
            overwritten for the gms / ModelExpress load formats.
        stat_logger: Optional stat-logger factory (this module passes a
            StatLoggerFactory). It is handed to vLLM as a stat logger and
            receives ``component_gauges``.

    Returns:
        Tuple ``(engine_client, vllm_config, default_sampling_params,
        prometheus_temp_dir, component_gauges)``. ``prometheus_temp_dir`` is the
        TemporaryDirectory created here, or None if PROMETHEUS_MULTIPROC_DIR was
        already set. Callers in this module pass it to the handler's
        add_temp_dir().

    Raises:
        ImportError: If a ModelExpress load format is requested but the
            ``modelexpress`` package cannot be imported.
    """
    # vLLM v0.11.0 bug: vllm/v1/metrics/prometheus.py:79 passes a TemporaryDirectory
    # object instead of its .name string, causing a false error on exit. Set
    # PROMETHEUS_MULTIPROC_DIR ourselves to avoid this and handle cleanup properly.
    prometheus_temp_dir = None
    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
        prometheus_temp_dir = tempfile.TemporaryDirectory(prefix="vllm_prometheus_")
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_temp_dir.name
        logger.debug(
            f"Created PROMETHEUS_MULTIPROC_DIR at: {os.environ['PROMETHEUS_MULTIPROC_DIR']}"
        )

    setup_multiprocess_prometheus()  # call vLLM's library's function to setup multiprocess prometheus
    logger.debug(
        f"Prometheus multiproc dir set to: {os.environ.get('PROMETHEUS_MULTIPROC_DIR')}"
    )

    # Construct Prometheus gauges AFTER setup_multiprocess_prometheus() so Gauge objects
    # see the correct ValueClass (multiprocess vs in-memory).
    component_gauges = LLMBackendMetrics(
        registry=DYNAMO_COMPONENT_REGISTRY,
        model_name=config.served_model_name or "",
        component_name=config.component or "",
    )

    # If a StatLoggerFactory was provided, give it the gauges so the loggers
    # it creates can publish Prometheus metrics.
    if stat_logger is not None:
        stat_logger.component_gauges = component_gauges

    os.environ["VLLM_NO_USAGE_STATS"] = "1"  # Avoid internal HTTP requests
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    engine_args = config.engine_args

    if engine_args.enable_lora:
        # Only provide defaults; user-supplied values win.
        os.environ.setdefault("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "True")
        os.environ.setdefault("VLLM_LORA_MODULES_LOADING_TIMEOUT", "600")

    if engine_args.load_format == "gms":
        engine_args.worker_cls = "gpu_memory_service.integrations.vllm.worker.GMSWorker"

    if engine_args.load_format in ("mx-source", "mx-target"):
        # Guard only the import: an ImportError raised *inside*
        # register_modelexpress_loaders() must not be misreported as a
        # missing package.
        try:
            from modelexpress import register_modelexpress_loaders
        except ImportError as e:
            raise ImportError(
                f"ModelExpress package required for --load-format={engine_args.load_format}. "
                "Install with: pip install modelexpress"
            ) from e

        # Ensure the ModelExpress server URL env var is set for the model loader
        if config.model_express_url:
            os.environ["MODEL_EXPRESS_URL"] = config.model_express_url
        register_modelexpress_loaders()
        # Use wrapper worker to ensure loaders are registered in spawned worker processes
        engine_args.worker_cls = "modelexpress.vllm_worker.ModelExpressWorker"

    # Load default sampling params from `generation_config.json`
    default_sampling_params = (
        engine_args.create_model_config().get_diff_sampling_param()
    )

    # Taken from build_async_engine_client_from_engine_args()
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    # Set up consolidator endpoints if KVBM is enabled (best effort: any
    # failure only disables consolidation, with a warning).
    consolidator_endpoints = None
    if config.has_connector("kvbm"):
        try:
            from kvbm.vllm_integration.consolidator_config import (
                get_consolidator_endpoints,
            )

            consolidator_endpoints = get_consolidator_endpoints(vllm_config)
        except Exception as e:
            logger.warning(
                f"KVBM connector is enabled but failed to get consolidator endpoints: {e}. "
                "Continuing without KV event consolidation. "
                "Ensure 'kvbm' package is installed if this feature is needed."
            )
    vllm_config.consolidator_endpoints = consolidator_endpoints

    factory = [stat_logger] if stat_logger else []

    # Time engine initialization with a monotonic clock: wall-clock time.time()
    # can jump (NTP adjustments) and distort the recorded duration.
    start_time = time.perf_counter()
    engine_client = AsyncLLM.from_vllm_config(
        vllm_config=vllm_config,
        usage_context=usage_context,
        stat_loggers=factory,
        enable_log_requests=engine_args.enable_log_requests,
        disable_log_stats=engine_args.disable_log_stats,
    )
    load_time = time.perf_counter() - start_time

    # Record model load time
    component_gauges.set_model_load_time(load_time)

    logger.info(f"VllmWorker for {config.served_model_name} has been initialized")

    return (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
        component_gauges,
    )
Alec's avatar
Alec committed
455
456


457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
async def register_vllm_model(
    model_input: ModelInput,
    model_type: ModelType,
    generate_endpoint,
    config: Config,
    engine_client: AsyncLLM,
    vllm_config,
):
    """
    Helper function to register a vLLM model with runtime configuration.

    It first queries the engine's cache info (KV blocks, max seqs, max batched
    tokens, block size) and builds a ModelRuntimeConfig from it. It then adds
    parsers, the data-parallel size, and optional frontend media decoding, and
    finally calls register_model().

    Args:
        model_input: Input type for the model (e.g., ModelInput.Tokens)
        model_type: Type of model (e.g., ModelType.Chat, ModelType.Prefill)
        generate_endpoint: Endpoint to register
        config: Configuration object
        engine_client: vLLM engine client
        vllm_config: vLLM configuration

    Raises:
        RuntimeError: If ``config.frontend_decoding`` is set but MediaDecoder /
            MediaFetcher could not be imported from dynamo.llm.
    """
    runtime_config = ModelRuntimeConfig()

    # Get runtime configuration from vLLM engine. Use the module logger, as
    # the rest of this file does, instead of the root logger.
    logger.info(
        f"Getting engine runtime configuration metadata from vLLM engine for {model_type}..."
    )
    runtime_values = get_engine_cache_info(engine_client)
    runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
    runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
    runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]

    # Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer
    runtime_config.enable_local_indexer = (
        config.enable_local_indexer and not config.is_decode_worker
    )

    # Add tool/reasoning parsers for decode models
    if model_type != ModelType.Prefill:
        runtime_config.tool_call_parser = config.dyn_tool_call_parser
        runtime_config.reasoning_parser = config.dyn_reasoning_parser

    # Get data_parallel_size from vllm_config (defaults to 1)
    data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
    runtime_config.data_parallel_size = data_parallel_size

    # Configure media decoder for frontend image decoding when enabled.
    # This enables the frontend to decode images and transfer them via NIXL RDMA.
    media_decoder = None
    media_fetcher = None
    if config.frontend_decoding:
        if not MEDIA_DECODER_AVAILABLE:
            raise RuntimeError(
                "--frontend-decoding requires MediaDecoder support. "
                "Ensure dynamo.llm module includes MediaDecoder and MediaFetcher."
            )
        media_decoder = MediaDecoder()
        # 128 MiB "max_alloc" limit for image decoding (presumably a per-image
        # allocation cap -- confirm against MediaDecoder docs).
        media_decoder.enable_image({"limits": {"max_alloc": 128 * 1024 * 1024}})
        # NOTE: video decoding is intentionally not enabled here
        # (would be media_decoder.enable_video({})).

        media_fetcher = MediaFetcher()
        media_fetcher.timeout_ms(30000)  # 30 s fetch timeout
        media_fetcher.allow_direct_port(True)

    await register_model(
        model_input,
        model_type,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=runtime_values["block_size"],
        runtime_config=runtime_config,
        custom_template_path=config.custom_jinja_template,
        media_decoder=media_decoder,
        media_fetcher=media_fetcher,
    )


532
async def init_prefill(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    pre_created_engine=None,
):
    """
    Instantiate and serve a prefill worker.

    Steps:
      1. Reuse ``pre_created_engine`` (checkpoint mode) or create a new engine
         with setup_vllm_engine().
      2. Build a PrefillWorkerHandler and attach KV event publishers (optionally
         via the KVBM consolidator).
      3. Set up metrics and register the sleep/wake_up engine routes.
      4. On the leader data-parallel rank only, register the model as
         ModelType.Prefill and serve the ``generate`` and ``clear_kv_blocks``
         endpoints until shutdown.

    Non-leader ranks (truthy ``data_parallel_rank``) never serve and block in
    _handle_non_leader_node(). handler.cleanup() runs whenever serving ends,
    normally or via an exception. NOTE(review): the non-leader path returns
    before the try/finally, so it does not call handler.cleanup().

    Args:
        runtime: Dynamo DistributedRuntime.
        config: Worker configuration.
        shutdown_event: Event shared with the handler for shutdown signalling.
        pre_created_engine: Optional tuple returned by setup_vllm_engine().
    """
    generate_endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.{config.endpoint}"
    )
    component = generate_endpoint.component()
    clear_endpoint = component.endpoint("clear_kv_blocks")

    # Same fallback as init(): config.served_model_name is set by worker() in
    # practice, but fall back to config.model so every endpoint gets the same,
    # non-empty label (and mypy is satisfied).
    model_name = config.served_model_name or config.model

    # Use pre-created engine if provided (checkpoint mode), otherwise create new
    engine_bundle = (
        pre_created_engine
        if pre_created_engine is not None
        else setup_vllm_engine(config)
    )
    (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
        _component_gauges,
    ) = engine_bundle

    handler = PrefillWorkerHandler(
        runtime,
        component,
        engine_client,
        default_sampling_params,
        getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
        shutdown_event=shutdown_event,
        enable_frontend_decoding=config.frontend_decoding,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # Check if the KV event consolidator is enabled (endpoints were resolved in
    # setup_vllm_engine).
    consolidator_enabled = False
    consolidator_port = None
    consolidator_endpoints = getattr(vllm_config, "consolidator_endpoints", None)
    if consolidator_endpoints:
        # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint).
        # Clients subscribe via the connect endpoint (third element).
        consolidator_output_endpoint = consolidator_endpoints[2]
        consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
        consolidator_enabled = True

    # Set up KV event publishers for prefix caching if enabled (one per dp_rank).
    # If the KV event consolidator is enabled, publishers subscribe to its output.
    kv_publishers = setup_kv_event_publisher(
        config,
        component,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    setup_metrics_collection(config, generate_endpoint, logger)

    # Register sleep/wake_up engine routes
    runtime.register_engine_route("sleep", handler.sleep)
    runtime.register_engine_route("wake_up", handler.wake_up)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # Handle non-leader nodes - don't serve endpoints
    if config.engine_args.data_parallel_rank:
        await _handle_non_leader_node(config.engine_args.data_parallel_rank)
        return

    # Register prefill model with ModelType.Prefill
    model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
    await register_vllm_model(
        model_input,
        ModelType.Prefill,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
    )

    health_check_payload = VllmPrefillHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    ).to_dict()

    try:
        logger.debug("Starting serve_endpoint for prefill worker")
        await asyncio.gather(
            # For prefill, we want to shut down the engine only after all prefill
            # requests are finished, because:
            #     (temp reason): we don't support re-routing prefill requests
            #     (long-term reason): the prefill engine should pull from a global
            #                         queue, so there are only a few in-flight
            #                         requests that can be finished quickly
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=[
                    (prometheus_names.labels.MODEL, model_name),
                    (prometheus_names.labels.MODEL_NAME, model_name),
                ],
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=[
                    (prometheus_names.labels.MODEL, model_name),
                    (prometheus_names.labels.MODEL_NAME, model_name),
                ],
            ),
        )
        logger.debug("serve_endpoint completed for prefill worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up prefill worker")
        handler.cleanup()


674
async def init(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    pre_created_engine=None,
):
    """
    Instantiate and serve a vLLM decode worker.

    Creates (or reuses) the vLLM engine, builds the ``DecodeWorkerHandler``,
    wires up KV event publishing, metrics collection and the sleep/wake_up
    engine routes, registers the model, and then serves the ``generate`` and
    ``clear_kv_blocks`` endpoints (plus ``load_lora`` / ``unload_lora`` /
    ``list_loras`` when LoRA is enabled).

    Args:
        runtime: Distributed runtime used to resolve endpoints and register
            engine routes.
        config: Parsed worker configuration.
        shutdown_event: Event passed through to the handler.
        pre_created_engine: Optional 5-tuple ``(engine_client, vllm_config,
            default_sampling_params, prometheus_temp_dir, component_gauges)``
            from an already-created engine (checkpoint mode). When ``None``,
            a new engine is created via ``setup_vllm_engine``.

    Raises:
        Any exception raised while setting up, registering or serving is
        propagated; ``handler.cleanup()`` runs on every exit path once the
        handler exists.
    """
    generate_endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.{config.endpoint}"
    )
    component = generate_endpoint.component()
    clear_endpoint = component.endpoint("clear_kv_blocks")

    lora_enabled = config.engine_args.enable_lora
    if lora_enabled:
        load_lora_endpoint = component.endpoint("load_lora")
        unload_lora_endpoint = component.endpoint("unload_lora")
        list_loras_endpoint = component.endpoint("list_loras")

    # In practice config.served_model_name is always set; fall back to the
    # model path so metric labels are never None.
    model_name = config.served_model_name or config.model
    dp_rank = config.engine_args.data_parallel_rank or 0

    # Use pre-created engine if provided (checkpoint mode), otherwise create new
    if pre_created_engine is not None:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            component_gauges,
        ) = pre_created_engine
        # Factory is created after unpack so component_gauges is available
        factory = StatLoggerFactory(
            component,
            component_gauges=component_gauges,
            dp_rank=dp_rank,
            metrics_labels=[("model", model_name)],
        )
    else:
        # Factory is created without component_gauges; setup_vllm_engine() will
        # create the gauges after setup_multiprocess_prometheus() and set them
        # on the factory before vLLM calls create_stat_logger().
        factory = StatLoggerFactory(
            component,
            dp_rank=dp_rank,
            metrics_labels=[("model", model_name)],
        )
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            component_gauges,
        ) = setup_vllm_engine(config, factory)

    # TODO Hack to get data, move this to registering in TBD
    factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks)
    factory.init_publish()

    handler = DecodeWorkerHandler(
        runtime,
        component,
        engine_client,
        default_sampling_params,
        getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
        shutdown_event=shutdown_event,
        enable_frontend_decoding=config.frontend_decoding,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # From here on the handler owns background tasks and the Prometheus temp
    # dir, so every exit path (setup/registration errors, the non-leader early
    # return, normal shutdown) must run handler.cleanup().
    try:
        # Check if kv event consolidator is enabled (port was allocated in
        # setup_vllm_engine).
        consolidator_enabled = False
        consolidator_port = None
        consolidator_endpoints = getattr(vllm_config, "consolidator_endpoints", None)
        if consolidator_endpoints:
            # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint);
            # clients subscribe to the connect endpoint (third element).
            consolidator_output_endpoint = consolidator_endpoints[2]
            consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
            consolidator_enabled = True

        # Set up KV event publisher for prefix caching if enabled.
        # If kv event consolidator is enabled, publisher will subscribe to kv
        # event consolidator's output.
        kv_publishers = setup_kv_event_publisher(
            config,
            component,
            generate_endpoint,
            vllm_config,
            consolidator_enabled=consolidator_enabled,
            consolidator_port=consolidator_port,
        )
        if kv_publishers:
            handler.kv_publishers = kv_publishers

        setup_metrics_collection(config, generate_endpoint, logger)

        # Register sleep/wake_up engine routes
        runtime.register_engine_route("sleep", handler.sleep)
        runtime.register_engine_route("wake_up", handler.wake_up)
        logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

        # Handle non-leader nodes - don't serve endpoints
        if config.engine_args.data_parallel_rank:
            await _handle_non_leader_node(config.engine_args.data_parallel_rank)
            return

        # Parse endpoint types from --endpoint-types flag
        model_type = parse_endpoint_types(config.endpoint_types)
        logger.info(f"Registering model with endpoint types: {config.endpoint_types}")

        model_input = (
            ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
        )

        # Warn if custom template provided but chat endpoint not enabled
        if config.custom_jinja_template and "chat" not in config.endpoint_types:
            logger.warning(
                "Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
                "The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
            )

        await register_vllm_model(
            model_input,
            model_type,
            generate_endpoint,
            config,
            engine_client,
            vllm_config,
        )

        health_check_payload = VllmHealthCheckPayload(
            engine_client, use_text_input=config.use_vllm_tokenizer
        ).to_dict()

        try:
            logger.debug("Starting serve_endpoint for decode worker")

            model_metrics_labels = [
                (prometheus_names.labels.MODEL, model_name),
                (prometheus_names.labels.MODEL_NAME, model_name),
            ]

            serve_tasks = [
                # for decode, we want to transfer the in-flight requests to other
                # decode engines, because waiting them to finish can take a long
                # time for long OSLs
                generate_endpoint.serve_endpoint(
                    handler.generate,
                    graceful_shutdown=True,
                    metrics_labels=model_metrics_labels,
                    health_check_payload=health_check_payload,
                ),
                clear_endpoint.serve_endpoint(
                    handler.clear_kv_blocks,
                    metrics_labels=model_metrics_labels,
                ),
            ]

            if lora_enabled:
                serve_tasks.extend(
                    [
                        load_lora_endpoint.serve_endpoint(
                            handler.load_lora,
                            metrics_labels=model_metrics_labels,
                        ),
                        unload_lora_endpoint.serve_endpoint(
                            handler.unload_lora,
                            metrics_labels=model_metrics_labels,
                        ),
                        list_loras_endpoint.serve_endpoint(
                            handler.list_loras,
                            metrics_labels=model_metrics_labels,
                        ),
                    ]
                )

            await asyncio.gather(*serve_tasks)
            logger.debug("serve_endpoint completed for decode worker")
        except Exception as e:
            logger.error(f"Failed to serve endpoints: {e}")
            raise
    finally:
        logger.debug("Cleaning up decode worker")
        # Cleanup background tasks
        handler.cleanup()


873
874
875
876
877
878
879
def get_engine_cache_info(engine: AsyncLLM):
    """Retrieve cache and scheduler configuration from an [`AsyncLLM`] engine.

    Values are read directly from ``engine.vllm_config`` rather than via
    ``collective_rpc``.

    Args:
        engine: vLLM async engine whose ``vllm_config`` is inspected.

    Returns:
        dict with ``num_gpu_blocks`` and ``block_size`` (from the cache
        config) followed by ``max_num_seqs`` and ``max_num_batched_tokens``
        (from the scheduler config).

    Raises:
        Exception: any error raised while reading the config attributes is
        logged and re-raised.
    """
    try:
        cache_config = engine.vllm_config.cache_config
        scheduler_config = engine.vllm_config.scheduler_config
        cache_values = {
            "num_gpu_blocks": cache_config.num_gpu_blocks,
            "block_size": cache_config.block_size,
        }
        scheduler_values = {
            "max_num_seqs": scheduler_config.max_num_seqs,
            "max_num_batched_tokens": scheduler_config.max_num_batched_tokens,
        }
    except Exception as e:
        # Use the module logger (not the root `logging` module) like the rest
        # of this file.
        logger.error(f"Failed to get configuration values from vLLM config: {e}")
        raise

    logger.info(f"Cache config values: {cache_values}")
    logger.info(f"Scheduler config values: {scheduler_values}")
    # Same keys, same order as before: cache values first, then scheduler.
    return {**cache_values, **scheduler_values}


901
902
903
async def init_omni(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """Initialize Omni worker for multi-stage pipeline generation using vLLM-Omni.

    Supports text-to-text, text-to-image, and text-to-video generation
    through a single unified OmniHandler.

    Args:
        runtime: Distributed runtime used to resolve the generate endpoint.
        config: Parsed worker configuration.
        shutdown_event: Event passed through to the OmniHandler.

    Raises:
        Any exception raised while setting up, registering or serving is
        propagated; ``handler.cleanup()`` runs on every exit path once the
        handler exists.
    """
    from dynamo.vllm.omni import OmniHandler

    generate_endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.{config.endpoint}"
    )
    component = generate_endpoint.component()

    # Initialize media filesystem for storing generated images/videos
    media_fs = (
        get_fs(config.media_output_fs_url) if config.media_output_fs_url else None
    )

    # Initialize unified OmniHandler
    handler = OmniHandler(
        runtime=runtime,
        component=component,
        config=config,
        default_sampling_params={},
        shutdown_event=shutdown_event,
        media_output_fs=media_fs,
        media_output_http_url=config.media_output_http_url,
    )

    logger.info(f"Omni worker initialized for model: {config.model}")

    # Once the handler exists, every exit path (setup/registration errors,
    # the non-leader early return, normal shutdown) must run cleanup.
    try:
        # Set up metrics collection for vLLM and LMCache metrics
        setup_metrics_collection(config, generate_endpoint, logger)

        # Handle non-leader nodes - don't serve endpoints
        if config.engine_args.data_parallel_rank:
            await _handle_non_leader_node(config.engine_args.data_parallel_rank)
            return

        # TODO: extend for multi-stage pipelines
        model_type = get_output_modalities(config.output_modalities, config.model)
        if model_type is None:
            # Default to Images
            model_type = ModelType.Images
        await register_model(
            ModelInput.Text,
            model_type,
            generate_endpoint,
            config.model,
            config.served_model_name,
            kv_cache_block_size=config.engine_args.block_size,
        )

        logger.info("Starting to serve Omni worker endpoint...")

        health_check_payload = (
            await VllmOmniHealthCheckPayload.create(handler.engine_client)
        ).to_dict()

        # Fall back to the model path so metric labels are never None.
        model_name = config.served_model_name or config.model

        try:
            await generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=[
                    (prometheus_names.labels.MODEL, model_name),
                    (prometheus_names.labels.MODEL_NAME, model_name),
                ],
                health_check_payload=health_check_payload,
            )
        except Exception as e:
            logger.error(f"Failed to serve Omni endpoint: {e}")
            raise
    finally:
        logger.debug("Cleaning up Omni worker")
        handler.cleanup()


Alec's avatar
Alec committed
986
987
988
989
def main():
    """Console entry point: drive the async ``worker()`` coroutine on uvloop."""
    worker_coro = worker()
    uvloop.run(worker_coro)


if __name__ == "__main__":
    # Only start the worker when executed as a script, not on import.
    main()