main.py 39.2 KB
Newer Older
1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Alec's avatar
Alec committed
2
3
# SPDX-License-Identifier: Apache-2.0

4
import argparse
Alec's avatar
Alec committed
5
6
7
import asyncio
import logging
import os
8
import tempfile
9
import time
jh-nv's avatar
jh-nv committed
10
from typing import Any, Optional
Alec's avatar
Alec committed
11
12

import uvloop
13
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
14
from vllm.config import VllmConfig
Alec's avatar
Alec committed
15
16
17
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
18
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
Alec's avatar
Alec committed
19

20
from dynamo import prometheus_names
21
from dynamo.common.config_dump import dump_config
22
from dynamo.common.utils.endpoint_types import parse_endpoint_types
23
from dynamo.common.utils.graceful_shutdown import install_signal_handlers
24
25
26
27
from dynamo.common.utils.prometheus import (
    LLMBackendMetrics,
    register_engine_metrics_callback,
)
28
from dynamo.common.utils.runtime import create_runtime
Alec's avatar
Alec committed
29
from dynamo.llm import (
30
    KvEventPublisher,
31
    ModelInput,
32
    ModelRuntimeConfig,
Alec's avatar
Alec committed
33
    ModelType,
34
35
    fetch_model,
    register_model,
Alec's avatar
Alec committed
36
)
37
from dynamo.runtime import DistributedRuntime, Endpoint
Alec's avatar
Alec committed
38
from dynamo.runtime.logging import configure_dynamo_logging
39
from dynamo.vllm.worker_factory import WorkerFactory
Alec's avatar
Alec committed
40

41
from . import envs
42
from .args import Config, _uses_dynamo_connector, parse_args
43
from .constants import DisaggregationMode
44
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler, get_dp_range_for_worker
Ayush Agarwal's avatar
Ayush Agarwal committed
45
from .health_check import VllmHealthCheckPayload, VllmPrefillHealthCheckPayload
46
from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory
47
from .snapshot import prepare_snapshot_engine
Alec's avatar
Alec committed
48

jh-nv's avatar
jh-nv committed
49
50
51
52
53
54
55
56
57
58
59
60
# Optional imports for frontend decoding support.
# MediaDecoder / MediaFetcher are pre-declared as None so the names always exist;
# if dynamo.llm does not export them, both stay None and MEDIA_DECODER_AVAILABLE
# is False. register_vllm_model() checks that flag before using either class.
MediaDecoder: type | None = None
MediaFetcher: type | None = None
try:
    from dynamo.llm import MediaDecoder, MediaFetcher

    MEDIA_DECODER_AVAILABLE = True
except ImportError:
    # Import failed: bind both names to None and record that decoding is unavailable.
    MediaDecoder = None
    MediaFetcher = None
    MEDIA_DECODER_AVAILABLE = False

Alec's avatar
Alec committed
61
62
# Configure Dynamo logging at import time, then create this module's logger.
configure_dynamo_logging()
logger = logging.getLogger(__name__)
63
# Module-level list that worker() passes to install_signal_handlers() and to
# WorkerFactory.create().
shutdown_endpoints: list = []
Alec's avatar
Alec committed
64
65


66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def build_headless_namespace(config: Config) -> argparse.Namespace:
    """Rebuild the CLI-style namespace consumed by vLLM's run_headless().

    run_headless() expects the raw CLI namespace. It is reconstructed here from
    the already-parsed ``config.engine_args`` so that parse_args() does not
    need to leak transport details.

    Args:
        config: Worker configuration; only ``config.engine_args`` is read.

    Returns:
        A new ``argparse.Namespace`` carrying every attribute of
        ``config.engine_args``, plus ``api_server_count`` set to 0 when the
        engine args do not define it.
    """
    fields = dict(vars(config.engine_args))
    # run_headless() reads api_server_count; fall back to 0 (no API server).
    fields.setdefault("api_server_count", 0)
    return argparse.Namespace(**fields)


def run_dynamo_headless(config: Config) -> None:
    """Hand control to vLLM's headless entry point for multi-node TP/PP.

    Secondary nodes spawn vLLM workers only — no engine core, no scheduler,
    no Dynamo endpoints. DistributedRuntime is bypassed entirely (no NATS/etcd).

    Args:
        config: Worker configuration; its engine args are converted into the
            namespace that run_headless() expects.
    """
    # Deferred import: tests that only exercise build_headless_namespace()
    # should not pull in vLLM's full CLI import graph.
    from vllm.entrypoints.cli.serve import run_headless

    run_headless(build_headless_namespace(config))


jh-nv's avatar
jh-nv committed
94
async def worker() -> None:
    """Parse the configuration and run this process as a vLLM worker.

    Steps, in order: parse and dump the config, settle the served model name,
    fetch the model when it is not a local path, optionally prepare a snapshot
    engine, then either run headless or create the runtime and dispatch to the
    matching initialization routine.
    """
    config = parse_args()
    dump_config(config.dump_config_to, config)

    # Name the model. Use either the full path (vllm and sglang do the same),
    # or the HF name (e.g. "Qwen/Qwen3-0.6B"), depending on cmd line params.
    if not config.served_model_name:
        model_name = config.model
        config.served_model_name = model_name
        config.engine_args.served_model_name = model_name

    # Download the model if necessary using modelexpress, so it is on disk
    # before vllm starts and vllm does not download from HuggingFace.
    #
    # `config.engine_args.model` is deliberately NOT set to the local path
    # fetch_model returns: vllm sends that name to its Ray pipeline-parallel
    # workers, which may not have the local path. vllm will attempt to download
    # the model again, but find it in the HF cache. For non-HF models use a path
    # instead of an HF name, and ensure all workers have that path (ideally via
    # a shared folder).
    model_is_local_path = os.path.exists(config.model)
    if not model_is_local_path:
        await fetch_model(config.model)

    # CHECKPOINT MODE: load the engine BEFORE runtime creation so GPU state can
    # be checkpointed before runtime connections are established.
    controller = await prepare_snapshot_engine(config, setup_vllm_engine)
    snapshot_engine = None
    if controller is not None:
        snapshot_engine = controller.engine
        restored_namespace, restored_backend = controller.reload_restore_identity()
        config.namespace = restored_namespace
        config.discovery_backend = restored_backend

    # HEADLESS MODE: bypass DistributedRuntime entirely.
    # Workers run vLLM only (no NATS, etcd, or dynamo endpoints).
    if config.headless:
        run_dynamo_headless(config)
        return

    shutdown_event = asyncio.Event()
    runtime, loop = create_runtime(
        discovery_backend=config.discovery_backend,
        request_plane=config.request_plane,
        event_plane=config.event_plane,
        use_kv_events=config.use_kv_events,
    )
    install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event)

    # Route to the appropriate initialization based on config flags.
    if WorkerFactory.handles(config):
        # The factory receives the setup functions defined in this module.
        await WorkerFactory(
            setup_vllm_engine_fn=setup_vllm_engine,
            setup_kv_event_publisher_fn=setup_kv_event_publisher,
            register_vllm_model_fn=register_vllm_model,
        ).create(
            runtime,
            config,
            shutdown_event,
            shutdown_endpoints,
            snapshot_engine=snapshot_engine,
        )
        logger.debug("multimodal worker completed")
    else:
        if config.disaggregation_mode == DisaggregationMode.PREFILL:
            init_fn, done_message = init_prefill, "init_prefill completed"
        else:
            init_fn, done_message = init, "init completed"
        await init_fn(
            runtime,
            config,
            shutdown_event,
            snapshot_engine=snapshot_engine,
        )
        logger.debug(done_message)

    logger.debug("Worker function completed, exiting...")
Alec's avatar
Alec committed
181
182


jh-nv's avatar
jh-nv committed
183
184
185
def setup_metrics_collection(
    config: Config, generate_endpoint: Endpoint, logger: logging.Logger
) -> None:
    """Register the engine-metrics callbacks for vLLM and LMCache metrics.

    Nothing is registered unless ``config.engine_args.disable_log_stats`` is
    exactly ``False``.

    In multiprocess mode (PROMETHEUS_MULTIPROC_DIR set), metrics are stored:
      1. In-memory: Metric objects in the global REGISTRY
      2. On-disk: Metric values in .db files (PROMETHEUS_MULTIPROC_DIR)

    MultiProcessCollector reads from the .db files, but adding it to REGISTRY
    can fail with "Duplicated timeseries" if PROMETHEUS_MULTIPROC_DIR was set
    before the process started (K8s deployments), because the metrics are
    already in REGISTRY.

    Approach: try adding MultiProcessCollector to REGISTRY. If that raises
    ValueError, use a separate registry for multiprocess collection and register
    callbacks on both registries so all metrics (vllm, lmcache,
    dynamo_component) are collected.

    Labels:
        The namespace, component, endpoint and model names from ``config`` are
        passed to every vllm/lmcache callback registration so that the engine
        metrics carry the same hierarchy labels as the Rust auto-labels.

    Args:
        config: Worker configuration (engine args plus naming fields).
        generate_endpoint: Endpoint the callbacks are registered on.
        logger: Logger used for the debug messages emitted here.
    """
    if config.engine_args.disable_log_stats is not False:
        return

    def _register_engine_metrics(registry, prefixes):
        # One registration with the hierarchy labels read from config.
        register_engine_metrics_callback(
            endpoint=generate_endpoint,
            registry=registry,
            metric_prefix_filters=prefixes,
            namespace_name=config.namespace,
            component_name=config.component,
            endpoint_name=config.endpoint,
            model_name=config.model,
        )

    # Register the dedicated dynamo_component registry callback.
    # IMPORTANT: MultiProcessCollector is NOT used for DYNAMO_COMPONENT_REGISTRY
    # because our gauges use in-memory values, which work fine for single-process
    # and multi-process (each process has its own gauge with dp_rank label).
    # Using MultiProcessCollector would read from .db files, which causes stale
    # values to accumulate across test runs.
    register_engine_metrics_callback(
        endpoint=generate_endpoint,
        registry=DYNAMO_COMPONENT_REGISTRY,
    )

    if not os.environ.get("PROMETHEUS_MULTIPROC_DIR"):
        # No multiprocess mode: the global registry is the only source.
        _register_engine_metrics(REGISTRY, ["vllm:", "lmcache:"])
        return

    try:
        # MultiProcessCollector reads metrics from .db files in
        # PROMETHEUS_MULTIPROC_DIR; adding it to REGISTRY allows collecting both
        # in-memory and .db file metrics from one registry.
        multiprocess.MultiProcessCollector(REGISTRY)
        logger.debug("Added MultiProcessCollector to global REGISTRY")
        _register_engine_metrics(REGISTRY, ["vllm:", "lmcache:"])
    except ValueError as e:
        # Conflict: metrics already in REGISTRY and MultiProcessCollector tries
        # to add the same metrics from .db files. Fall back to a separate
        # registry that ONLY reads from .db files (no in-memory conflicts).
        logger.debug(
            f"Could not add MultiProcessCollector to REGISTRY ({e}), using separate registry"
        )
        multiproc_registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(multiproc_registry)

        # Register both registries to collect all metrics.
        # Global REGISTRY has the in-memory metrics (vllm).
        _register_engine_metrics(REGISTRY, ["vllm:"])
        # Multiproc registry has .db file metrics (lmcache, possibly vllm duplicates).
        _register_engine_metrics(multiproc_registry, ["vllm:", "lmcache:"])


Yan Ru Pei's avatar
Yan Ru Pei committed
275
276
def setup_kv_event_publisher(
    config: Config,
    generate_endpoint: Endpoint,
    vllm_config: VllmConfig,
    consolidator_enabled: bool = False,
    consolidator_port: Optional[int] = 5558,
) -> Optional[list[KvEventPublisher]]:
    """
    Set up KV event publishers for prefix caching if enabled.

    Creates one publisher per dp_rank since each dp_rank publishes to a different port.

    Args:
        config: Worker configuration
        generate_endpoint: Endpoint for worker ID
        vllm_config: vLLM configuration
        consolidator_enabled: If True, subscribe to the KV event consolidator's ZMQ endpoint
        consolidator_port: Port where the KV event consolidator publishes (default: 5558)

    Returns:
        List of KvEventPublisher instances (one per dp_rank) if prefix caching is enabled,
        None otherwise. None is also returned for decode workers, when no
        kv_events_config is set, or when enable_kv_cache_events is False.

    Raises:
        ValueError: If consolidator_enabled is True but consolidator_port is None.
    """
    engine_args = config.engine_args
    if not engine_args.enable_prefix_caching:
        return None

    # Skip KV event publishing for decode workers
    if config.disaggregation_mode == DisaggregationMode.DECODE:
        logger.info("Skipping KV event publisher setup for decode worker")
        return None

    kv_events_config = engine_args.kv_events_config
    if kv_events_config is None:
        return None

    # Check if kv_cache_events are explicitly disabled
    if not kv_events_config.enable_kv_cache_events:
        logger.info(
            "KV event publishing skipped: enable_kv_cache_events=False in kv_events_config"
        )
        return None

    # Fail fast on a misconfigured consolidator instead of silently building
    # the unusable endpoint "tcp://127.0.0.1:None".
    if consolidator_enabled and consolidator_port is None:
        raise ValueError(
            "consolidator_port must be set when consolidator_enabled is True"
        )

    # Get DP rank range managed by this worker to create publishers for corresponding dp_ranks,
    # all served workers should cover all ranks.
    dp_start, dp_size = get_dp_range_for_worker(vllm_config)
    kv_publishers = []

    for dp_rank in range(dp_start, dp_start + dp_size):
        if consolidator_enabled:
            # TODO: Use different port for each dp_rank once KVBM supports DP
            zmq_endpoint = f"tcp://127.0.0.1:{consolidator_port}"
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to consolidator at {zmq_endpoint}"
            )
        else:
            # Each dp_rank publishes to a different port; the wildcard bind
            # address is replaced with loopback to get a connectable endpoint.
            zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
                kv_events_config.endpoint,
                data_parallel_rank=dp_rank,
            ).replace("*", "127.0.0.1")
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to vLLM at {zmq_endpoint}"
            )

        kv_publisher = KvEventPublisher(
            endpoint=generate_endpoint,
            kv_block_size=vllm_config.cache_config.block_size,
            zmq_endpoint=zmq_endpoint,
            zmq_topic="",
            enable_local_indexer=config.enable_local_indexer,
            dp_rank=dp_rank,
        )
        kv_publishers.append(kv_publisher)

        logger.info(
            f"Worker reading KV events for dp_rank={dp_rank} from {zmq_endpoint}"
        )

    return kv_publishers if kv_publishers else None
Yan Ru Pei's avatar
Yan Ru Pei committed
351
352


353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
def setup_fpm_relay(
    generate_endpoint: Endpoint,
    vllm_config: VllmConfig,
) -> Optional[list]:
    """
    Create the forward pass metrics (FPM) relays for the event plane.

    One FpmEventRelay is built per dp_rank handled by this worker. Each relay
    subscribes to the local raw ZMQ PUB from InstrumentedScheduler (in the
    EngineCore child process) and re-publishes to the Dynamo event plane with
    automatic discovery registration. The ZMQ port for a rank is
    DYN_FORWARDPASS_METRIC_PORT plus the dp_rank.

    Args:
        generate_endpoint: Endpoint handed to every relay.
        vllm_config: vLLM configuration used to determine the dp_rank range.

    Returns:
        List of FpmEventRelay instances, or None if FPM is not enabled, the
        FpmEventRelay binding cannot be imported, or no relay was created.
    """
    fpm_enabled = envs.is_set("DYN_FORWARDPASS_METRIC_PORT")
    if not fpm_enabled:
        return None

    try:
        from dynamo.llm import FpmEventRelay
    except ImportError:
        logger.warning(
            "FpmEventRelay not available (Rust bindings not built with FPM support). "
            "Forward pass metrics will not be relayed to the event plane."
        )
        return None

    first_rank, rank_count = get_dp_range_for_worker(vllm_config)
    relays = []

    for dp_rank in range(first_rank, first_rank + rank_count):
        base_port = envs.DYN_FORWARDPASS_METRIC_PORT
        zmq_endpoint = f"tcp://127.0.0.1:{base_port + dp_rank}"
        relays.append(
            FpmEventRelay(
                endpoint=generate_endpoint,
                zmq_endpoint=zmq_endpoint,
            )
        )
        logger.info(f"FPM relay for dp_rank={dp_rank} subscribing to {zmq_endpoint}")

    return relays or None


jh-nv's avatar
jh-nv committed
398
def setup_vllm_engine(
    config: Config,
    stat_logger: Optional[StatLoggerFactory] = None,
    fpm_worker_id: Optional[str] = None,
) -> tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]:
    """Create the vLLM AsyncLLM engine together with its metrics objects.

    Side effects: sets several process environment variables
    (PROMETHEUS_MULTIPROC_DIR when unset, VLLM_NO_USAGE_STATS,
    VLLM_WORKER_MULTIPROC_METHOD, LoRA defaults, MODEL_EXPRESS_URL) and may
    mutate ``config.engine_args`` (worker_cls, ec_transfer_config).

    Args:
        config: Worker configuration; ``config.engine_args`` drives the engine.
        stat_logger: Optional stat logger factory. When given, it receives the
            component gauges and is passed to the engine as a stat logger.
        fpm_worker_id: Optional worker identity stored in
            ``vllm_config.additional_config["fpm_worker_id"]``.

    Returns:
        A tuple ``(engine_client, vllm_config, default_sampling_params,
        prometheus_temp_dir, component_gauges)``. ``prometheus_temp_dir`` is the
        TemporaryDirectory created here for PROMETHEUS_MULTIPROC_DIR, or None
        when the variable was already set by the environment.

    Raises:
        ImportError: If ``load_format`` is "mx-source" or "mx-target" and the
            ModelExpress package cannot be imported.
    """
    # vLLM v0.11.0 bug: vllm/v1/metrics/prometheus.py:79 passes TemporaryDirectory object
    # instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR
    # ourselves to avoid this and handle cleanup properly.
    prometheus_temp_dir = None
    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
        prometheus_temp_dir = tempfile.TemporaryDirectory(prefix="vllm_prometheus_")
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_temp_dir.name
        logger.debug(
            f"Created PROMETHEUS_MULTIPROC_DIR at: {os.environ['PROMETHEUS_MULTIPROC_DIR']}"
        )

    setup_multiprocess_prometheus()  # call vLLM's library's function to setup multiprocess prometheus
    logger.debug(
        f"Prometheus multiproc dir set to: {os.environ.get('PROMETHEUS_MULTIPROC_DIR')}"
    )

    # Construct Prometheus gauges AFTER setup_multiprocess_prometheus() so Gauge objects
    # see the correct ValueClass (multiprocess vs in-memory).
    component_gauges = LLMBackendMetrics(
        registry=DYNAMO_COMPONENT_REGISTRY,
        model_name=config.served_model_name or "",
        component_name=config.component or "",
    )

    # If a StatLoggerFactory was provided, give it the gauges so the loggers
    # it creates can publish Prometheus metrics.
    if stat_logger is not None:
        stat_logger.component_gauges = component_gauges

    os.environ["VLLM_NO_USAGE_STATS"] = "1"  # Avoid internal HTTP requests
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    engine_args = config.engine_args

    if engine_args.enable_lora:
        # Only fill in defaults; values already set in the environment win.
        os.environ.setdefault("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "True")
        os.environ.setdefault("VLLM_LORA_MODULES_LOADING_TIMEOUT", "600")

    if engine_args.load_format == "gms":
        engine_args.worker_cls = "gpu_memory_service.integrations.vllm.worker.GMSWorker"

    if engine_args.load_format in ("mx-source", "mx-target"):
        try:
            from modelexpress import register_modelexpress_loaders

            # Ensure the ModelExpress server URL env var is set for the model loader
            if config.model_express_url:
                os.environ["MODEL_EXPRESS_URL"] = config.model_express_url
            register_modelexpress_loaders()
            # Use wrapper worker to ensure loaders are registered in spawned worker processes
            engine_args.worker_cls = "modelexpress.vllm_worker.ModelExpressWorker"
        except ImportError as e:
            raise ImportError(
                f"ModelExpress package required for --load-format={engine_args.load_format}. "
                "Install with: pip install modelexpress"
            ) from e

    # Load default sampling params from `generation_config.json`
    default_sampling_params = (
        engine_args.create_model_config().get_diff_sampling_param()
    )

    # Configure ec_both mode with DynamoMultimodalEmbeddingCacheConnector.
    # Must happen BEFORE engine setup so vLLM sees ec_transfer_config.
    if (
        not config.route_to_encoder
        and config.multimodal_embedding_cache_capacity_gb > 0
    ):
        from vllm.config import ECTransferConfig

        logger.info(
            "Configuring ec_both mode with DynamoMultimodalEmbeddingCacheConnector "
            "(capacity=%.2f GB)",
            config.multimodal_embedding_cache_capacity_gb,
        )
        instance_id = 0
        engine_id = f"{config.namespace}.{config.component}.backend.{instance_id}"
        engine_args.ec_transfer_config = ECTransferConfig(
            engine_id=engine_id,
            ec_role="ec_both",
            ec_connector="DynamoMultimodalEmbeddingCacheConnector",
            ec_connector_module_path="dynamo.vllm.multimodal_utils.multimodal_embedding_cache_connector",
            ec_connector_extra_config={
                "multimodal_embedding_cache_capacity_gb": config.multimodal_embedding_cache_capacity_gb,
            },
        )
        logger.info("Configured ec_both with engine_id=%s", engine_id)

    # Taken from build_async_engine_client_from_engine_args()
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    # Set up consolidator endpoints if KVBM (DynamoConnector) is enabled
    consolidator_endpoints = None
    if _uses_dynamo_connector(config.engine_args):
        try:
            from kvbm.vllm_integration.consolidator_config import (
                get_consolidator_endpoints,
            )

            consolidator_endpoints = get_consolidator_endpoints(vllm_config)
        except Exception as e:
            # Deliberate best-effort: the engine still starts without consolidation.
            logger.warning(
                f"KVBM connector is enabled but failed to get consolidator endpoints: {e}. "
                "Continuing without KV event consolidation. "
                "Ensure 'kvbm' package is installed if this feature is needed."
            )
    # Store consolidator endpoints in additional_config (vLLM 0.16+ uses strict
    # dataclass fields; monkey-patching attributes onto VllmConfig is no longer safe).
    vllm_config.additional_config["consolidator_endpoints"] = consolidator_endpoints

    # Pass worker identity to InstrumentedScheduler via additional_config.
    if fpm_worker_id is not None:
        vllm_config.additional_config["fpm_worker_id"] = fpm_worker_id

    factory = []
    if stat_logger:
        factory.append(stat_logger)

    # Time engine initialization with a monotonic clock so wall-clock
    # adjustments during the load cannot distort the reported duration.
    start_time = time.perf_counter()
    engine_client = AsyncLLM.from_vllm_config(
        vllm_config=vllm_config,
        usage_context=usage_context,
        stat_loggers=factory,
        enable_log_requests=engine_args.enable_log_requests,
        disable_log_stats=engine_args.disable_log_stats,
    )
    load_time = time.perf_counter() - start_time

    # Record model load time
    component_gauges.set_model_load_time(load_time)

    logger.info(f"VllmWorker for {config.served_model_name} has been initialized")

    return (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
        component_gauges,
    )
Alec's avatar
Alec committed
547
548


549
550
551
async def register_vllm_model(
    model_input: ModelInput,
    model_type: ModelType,
    generate_endpoint: Endpoint,
    config: Config,
    engine_client: AsyncLLM,
    vllm_config: VllmConfig,
) -> None:
    """
    Helper function to register a vLLM model with runtime configuration.

    Args:
        model_input: Input type for the model (e.g., ModelInput.Tokens)
        model_type: Type of model (e.g., ModelType.Chat, ModelType.Prefill)
        generate_endpoint: Endpoint to register
        config: Configuration object
        engine_client: vLLM engine client
        vllm_config: vLLM configuration

    Raises:
        RuntimeError: If ``config.frontend_decoding`` is set but MediaDecoder /
            MediaFetcher could not be imported from dynamo.llm.
    """
    runtime_config = ModelRuntimeConfig()

    # Get runtime configuration from vLLM engine.
    # Use the module-level logger (not the root logger) like the rest of this file.
    logger.info(
        f"Getting engine runtime configuration metadata from vLLM engine for {model_type}..."
    )
    runtime_values = get_engine_cache_info(engine_client)
    runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
    runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
    runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
    # Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer
    runtime_config.enable_local_indexer = (
        config.enable_local_indexer
        and config.disaggregation_mode != DisaggregationMode.DECODE
    )

    # Add tool/reasoning parsers for decode models
    if model_type != ModelType.Prefill:
        runtime_config.tool_call_parser = config.dyn_tool_call_parser
        runtime_config.reasoning_parser = config.dyn_reasoning_parser

    # Advertise the data-parallel rank range served by this worker.
    dp_start, dp_size = get_dp_range_for_worker(vllm_config)
    runtime_config.data_parallel_start_rank = dp_start
    runtime_config.data_parallel_size = dp_size

    # Configure media decoder for frontend image decoding when enabled
    # This enables frontend to decode images and transfer via NIXL RDMA
    media_decoder = None
    media_fetcher = None
    if config.frontend_decoding:
        if not MEDIA_DECODER_AVAILABLE:
            raise RuntimeError(
                "--frontend-decoding requires MediaDecoder support. "
                "Ensure dynamo.llm module includes MediaDecoder and MediaFetcher."
            )
        # Narrowing for type checkers only; availability was validated just above.
        assert MediaDecoder is not None and MediaFetcher is not None
        media_decoder = MediaDecoder()
        # max_alloc limit of 128 MiB for image decoding.
        media_decoder.enable_image({"limits": {"max_alloc": 128 * 1024 * 1024}})
        # media_decoder.enable_video({})

        media_fetcher = MediaFetcher()
        media_fetcher.timeout_ms(30000)
        media_fetcher.allow_direct_port(True)

    await register_model(
        model_input,
        model_type,
        generate_endpoint,
        config.model,
        config.served_model_name,
        context_length=vllm_config.model_config.max_model_len,
        kv_cache_block_size=runtime_values["block_size"],
        runtime_config=runtime_config,
        custom_template_path=config.custom_jinja_template,
        media_decoder=media_decoder,
        media_fetcher=media_fetcher,
    )


628
async def init_prefill(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    snapshot_engine: Optional[
        tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]
    ] = None,
) -> None:
    """
    Instantiate a prefill worker and serve its endpoints.

    Creates the generate and ``clear_kv_blocks`` endpoints, obtains an engine
    (the pre-created ``snapshot_engine`` when given, otherwise a new one from
    ``setup_vllm_engine``), builds a ``PrefillWorkerHandler``, attaches KV
    event publishers and forward pass metrics relays when they are returned,
    registers the ``sleep``/``wake_up`` engine routes, registers the model as
    ``ModelType.Prefill`` and then serves both endpoints concurrently.

    Args:
        runtime: Runtime used to create endpoints and register engine routes.
        config: Worker configuration.
        shutdown_event: Event handed to the handler.
        snapshot_engine: Optional pre-created engine tuple
            ``(engine_client, vllm_config, default_sampling_params,
            prometheus_temp_dir, component_gauges)``.

    Raises:
        Exception: Errors raised while serving are logged and re-raised;
            ``handler.cleanup()`` runs in either case.
    """
    endpoint_prefix = f"{config.namespace}.{config.component}"
    generate_endpoint = runtime.endpoint(f"{endpoint_prefix}.{config.endpoint}")
    clear_endpoint = runtime.endpoint(f"{endpoint_prefix}.clear_kv_blocks")

    fpm_worker_id = str(generate_endpoint.connection_id())

    # Use pre-created engine if provided (checkpoint mode), otherwise create new
    if snapshot_engine is None:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            _component_gauges,
        ) = setup_vllm_engine(config, fpm_worker_id=fpm_worker_id)
    else:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            _component_gauges,
        ) = snapshot_engine
        # TODO: The scheduler in the child process still has worker_id=""
        # because the engine was forked before the runtime existed.
        # Propagating the new ID to the child requires shared memory or
        # a restart of the EngineCore process.
        vllm_config.additional_config["fpm_worker_id"] = fpm_worker_id

    # Tolerate a vllm_config without model_config / max_model_len (yields None).
    max_model_len = getattr(
        getattr(vllm_config, "model_config", None), "max_model_len", None
    )
    handler = PrefillWorkerHandler(
        runtime,
        engine_client,
        default_sampling_params,
        max_model_len,
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
        shutdown_event=shutdown_event,
        enable_frontend_decoding=config.frontend_decoding,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # The kv event consolidator counts as enabled when endpoints were recorded
    # in additional_config (port was allocated in setup_vllm_engine).
    # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint);
    # clients subscribe to the connect endpoint, so its port is extracted.
    consolidator_eps = vllm_config.additional_config.get("consolidator_endpoints")
    consolidator_enabled = bool(consolidator_eps)
    consolidator_port = (
        int(consolidator_eps[2].split(":")[-1]) if consolidator_enabled else None
    )

    # Set up KV event publishers for prefix caching if enabled (one per dp_rank)
    # If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
    kv_publishers = setup_kv_event_publisher(
        config,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    # Set up forward pass metrics relay (child ZMQ -> event plane).
    # In checkpoint mode the engine was created before the runtime, so
    # ForwardPassMetrics.worker_id will be empty (relay still works).
    fpm_relays = setup_fpm_relay(generate_endpoint, vllm_config)
    if fpm_relays:
        handler.fpm_relays = fpm_relays

    setup_metrics_collection(config, generate_endpoint, logger)

    # Register sleep/wake_up engine routes
    for route_name, route_handler in (
        ("sleep", handler.sleep),
        ("wake_up", handler.wake_up),
    ):
        runtime.register_engine_route(route_name, route_handler)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # In-place update so existing references to the list see the new endpoints.
    shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]

    # Register prefill model with ModelType.Prefill
    if config.use_vllm_tokenizer:
        model_input = ModelInput.Text
    else:
        model_input = ModelInput.Tokens
    await register_vllm_model(
        model_input,
        ModelType.Prefill,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
    )

    health_check_payload = VllmPrefillHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    ).to_dict()

    try:
        logger.debug("Starting serve_endpoint for prefill worker")

        # In practice config.served_model_name is always set, but mypy needs the "or" here.
        model_label = config.served_model_name or config.model
        model_metrics_labels = [
            (prometheus_names.labels.MODEL, model_label),
            (prometheus_names.labels.MODEL_NAME, model_label),
        ]

        # for prefill, we want to shutdown the engine after all prefill requests are finished because
        #     (temp reason): we don't support re-routing prefill requests
        #     (long-term reason): prefill engine should pull from a global queue so there is
        #                         only a few in-flight requests that can be quickly finished
        generate_task = generate_endpoint.serve_endpoint(
            handler.generate,  # type: ignore
            graceful_shutdown=True,
            metrics_labels=model_metrics_labels,
            health_check_payload=health_check_payload,
        )
        clear_task = clear_endpoint.serve_endpoint(
            handler.clear_kv_blocks,  # type: ignore
            metrics_labels=model_metrics_labels,
        )
        await asyncio.gather(generate_task, clear_task)
        logger.debug("serve_endpoint completed for prefill worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up prefill worker")
        handler.cleanup()


785
async def init(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    snapshot_engine: Optional[
        tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]
    ] = None,
) -> None:
    """
    Instantiate a decode worker and serve its endpoints.

    Creates the generate and ``clear_kv_blocks`` endpoints (plus the
    ``load_lora``/``unload_lora``/``list_loras`` endpoints when
    ``config.engine_args.enable_lora`` is set), obtains an engine (the
    pre-created ``snapshot_engine`` when given, otherwise a new one from
    ``setup_vllm_engine``), builds a ``DecodeWorkerHandler``, attaches KV
    event publishers and forward pass metrics relays when they are returned,
    registers the ``sleep``/``wake_up`` engine routes, registers the model
    with the type parsed from ``config.endpoint_types`` and then serves all
    endpoints concurrently.

    Args:
        runtime: Runtime used to create endpoints and register engine routes.
        config: Worker configuration.
        shutdown_event: Event handed to the handler.
        snapshot_engine: Optional pre-created engine tuple
            ``(engine_client, vllm_config, default_sampling_params,
            prometheus_temp_dir, component_gauges)``.

    Raises:
        Exception: Errors raised while serving are logged and re-raised;
            ``handler.cleanup()`` runs in either case.
    """
    endpoint_prefix = f"{config.namespace}.{config.component}"
    generate_endpoint = runtime.endpoint(f"{endpoint_prefix}.{config.endpoint}")
    clear_endpoint = runtime.endpoint(f"{endpoint_prefix}.clear_kv_blocks")

    # In-place update so existing references to the list see the new endpoints.
    shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]

    lora_enabled = config.engine_args.enable_lora
    if lora_enabled:
        load_lora_endpoint = runtime.endpoint(f"{endpoint_prefix}.load_lora")
        unload_lora_endpoint = runtime.endpoint(f"{endpoint_prefix}.unload_lora")
        list_loras_endpoint = runtime.endpoint(f"{endpoint_prefix}.list_loras")
        shutdown_endpoints.extend(
            [load_lora_endpoint, unload_lora_endpoint, list_loras_endpoint]
        )

    fpm_worker_id = str(generate_endpoint.connection_id())

    # Use pre-created engine if provided (checkpoint mode), otherwise create new
    if snapshot_engine is None:
        # Factory is created without component_gauges; setup_vllm_engine() will
        # create the gauges after setup_multiprocess_prometheus() and set them
        # on the factory before vLLM calls create_stat_logger().
        factory = StatLoggerFactory(endpoint=generate_endpoint)
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            component_gauges,
        ) = setup_vllm_engine(config, factory, fpm_worker_id=fpm_worker_id)
    else:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            component_gauges,
        ) = snapshot_engine
        vllm_config.additional_config["fpm_worker_id"] = fpm_worker_id
        # Factory is created after unpack so component_gauges is available
        factory = StatLoggerFactory(
            endpoint=generate_endpoint,
            component_gauges=component_gauges,
        )

    # TODO Hack to get data, move this to registering in TBD
    num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
    factory.set_num_gpu_blocks_all(num_gpu_blocks)
    factory.init_publish()

    # Tolerate a vllm_config without model_config / max_model_len (yields None).
    max_model_len = getattr(
        getattr(vllm_config, "model_config", None), "max_model_len", None
    )
    handler = DecodeWorkerHandler(
        runtime,
        engine_client,
        default_sampling_params,
        max_model_len,
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
        shutdown_event=shutdown_event,
        enable_frontend_decoding=config.frontend_decoding,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # The kv event consolidator counts as enabled when endpoints were recorded
    # in additional_config (port was allocated in setup_vllm_engine).
    # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint);
    # clients subscribe to the connect endpoint, so its port is extracted.
    consolidator_eps = vllm_config.additional_config.get("consolidator_endpoints")
    consolidator_enabled = bool(consolidator_eps)
    consolidator_port = (
        int(consolidator_eps[2].split(":")[-1]) if consolidator_enabled else None
    )

    # Set up KV event publisher for prefix caching if enabled
    # If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
    kv_publishers = setup_kv_event_publisher(
        config,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    # Set up forward pass metrics relay (child ZMQ -> event plane).
    # In checkpoint mode the engine was created before the runtime, so
    # ForwardPassMetrics.worker_id will be empty (relay still works).
    fpm_relays = setup_fpm_relay(generate_endpoint, vllm_config)
    if fpm_relays:
        handler.fpm_relays = fpm_relays

    setup_metrics_collection(config, generate_endpoint, logger)

    # Register sleep/wake_up engine routes
    for route_name, route_handler in (
        ("sleep", handler.sleep),
        ("wake_up", handler.wake_up),
    ):
        runtime.register_engine_route(route_name, route_handler)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # Parse endpoint types from --endpoint-types flag
    model_type = parse_endpoint_types(config.endpoint_types)
    logger.info(f"Registering model with endpoint types: {config.endpoint_types}")

    if config.use_vllm_tokenizer:
        model_input = ModelInput.Text
    else:
        model_input = ModelInput.Tokens

    # Warn if custom template provided but chat endpoint not enabled
    chat_enabled = "chat" in config.endpoint_types
    if config.custom_jinja_template and not chat_enabled:
        logger.warning(
            "Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
            "The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
        )

    await register_vllm_model(
        model_input,
        model_type,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
    )

    health_check_payload = VllmHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    ).to_dict()

    try:
        logger.debug("Starting serve_endpoint for decode worker")

        # Both labels carry the same value; fall back to config.model when no
        # served model name is configured.
        model_label = config.served_model_name or config.model
        model_metrics_labels = [
            (prometheus_names.labels.MODEL, model_label),
            (prometheus_names.labels.MODEL_NAME, model_label),
        ]

        # for decode, we want to transfer the in-flight requests to other decode engines,
        # because waiting them to finish can take a long time for long OSLs
        # NOTE(review): graceful_shutdown=True is passed below, which may not
        # match the intent described above — confirm against serve_endpoint.
        serve_tasks = [
            generate_endpoint.serve_endpoint(
                handler.generate,  # type: ignore
                graceful_shutdown=True,
                metrics_labels=model_metrics_labels,
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=model_metrics_labels,
            ),
        ]

        if lora_enabled:
            serve_tasks.append(
                load_lora_endpoint.serve_endpoint(
                    handler.load_lora,
                    metrics_labels=model_metrics_labels,
                )
            )
            serve_tasks.append(
                unload_lora_endpoint.serve_endpoint(
                    handler.unload_lora,
                    metrics_labels=model_metrics_labels,
                )
            )
            serve_tasks.append(
                list_loras_endpoint.serve_endpoint(
                    handler.list_loras,
                    metrics_labels=model_metrics_labels,
                )
            )

        await asyncio.gather(*serve_tasks)
        logger.debug("serve_endpoint completed for decode worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up decode worker")
        # Cleanup background tasks
        handler.cleanup()


jh-nv's avatar
jh-nv committed
1000
def get_engine_cache_info(engine: AsyncLLM) -> dict[str, Any]:
    """Return cache and scheduler settings read from the engine's ``vllm_config``.

    Args:
        engine: Engine whose ``vllm_config.cache_config`` and
            ``vllm_config.scheduler_config`` are read.

    Returns:
        A dict with the keys ``num_gpu_blocks``, ``block_size``,
        ``max_num_seqs`` and ``max_num_batched_tokens``.

    Raises:
        Exception: Any error raised while reading the configuration is logged
            and re-raised unchanged.
    """
    try:
        # Get values directly from vllm_config instead of collective_rpc
        cache_config = engine.vllm_config.cache_config
        cache_values = {
            "num_gpu_blocks": cache_config.num_gpu_blocks,
            "block_size": cache_config.block_size,
        }

        scheduler_config = engine.vllm_config.scheduler_config
        scheduler_values = {
            "max_num_seqs": scheduler_config.max_num_seqs,
            "max_num_batched_tokens": scheduler_config.max_num_batched_tokens,
        }

        # Log through the module-level logger (as the rest of this file does)
        # rather than the root logger, so records carry this module's name.
        logger.info("Cache config values: %s", cache_values)
        logger.info("Scheduler config values: %s", scheduler_values)

        # Merging keeps the key order: cache keys first, then scheduler keys.
        return {**cache_values, **scheduler_values}
    except Exception as e:
        logger.error("Failed to get configuration values from vLLM config: %s", e)
        raise


jh-nv's avatar
jh-nv committed
1028
def main() -> None:
    """Run the ``worker()`` coroutine via ``uvloop.run``."""
    entrypoint = worker()
    uvloop.run(entrypoint)


Alec's avatar
Alec committed
1032
# Script entry point: start the worker only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()