main.py 43.1 KB
Newer Older
1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Alec's avatar
Alec committed
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0

import asyncio
import logging
import os
7
import tempfile
8
import time
9
from typing import Optional
Alec's avatar
Alec committed
10
11

import uvloop
12
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
Alec's avatar
Alec committed
13
14
15
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
16
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
Alec's avatar
Alec committed
17

18
from dynamo import prometheus_names
19
from dynamo.common.config_dump import dump_config
20
from dynamo.common.utils.endpoint_types import parse_endpoint_types
21
from dynamo.common.utils.output_modalities import get_output_modalities
22
23
24
25
from dynamo.common.utils.prometheus import (
    LLMBackendMetrics,
    register_engine_metrics_callback,
)
26
from dynamo.common.utils.runtime import create_runtime
Alec's avatar
Alec committed
27
from dynamo.llm import (
28
    KvEventPublisher,
29
    ModelInput,
30
    ModelRuntimeConfig,
Alec's avatar
Alec committed
31
    ModelType,
32
33
    fetch_model,
    register_model,
Alec's avatar
Alec committed
34
)
35
36
37
38
39
40
41
42
43
44
45

# Optional imports for frontend decoding support.
# MediaDecoder / MediaFetcher may be missing from dynamo.llm; in that case both
# names are bound to None and MEDIA_DECODER_AVAILABLE is False, so callers must
# check the flag before instantiating either class.
try:
    from dynamo.llm import MediaDecoder, MediaFetcher

    MEDIA_DECODER_AVAILABLE = True
except ImportError:
    MediaDecoder = None
    MediaFetcher = None
    MEDIA_DECODER_AVAILABLE = False

46
from dynamo.runtime import DistributedRuntime
Alec's avatar
Alec committed
47
from dynamo.runtime.logging import configure_dynamo_logging
48
49
from dynamo.vllm.multimodal_handlers import (
    EncodeWorkerHandler,
Ayush Agarwal's avatar
Ayush Agarwal committed
50
    MultimodalDecodeWorkerHandler,
51
52
    MultimodalPDWorkerHandler,
)
Alec's avatar
Alec committed
53

54
from .args import Config, parse_args
55
from .chrek import get_checkpoint_config
Alec's avatar
Alec committed
56
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
57
58
59
60
61
from .health_check import (
    VllmHealthCheckPayload,
    VllmOmniHealthCheckPayload,
    VllmPrefillHealthCheckPayload,
)
62
from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory
Alec's avatar
Alec committed
63

Alec's avatar
Alec committed
64
65
# Apply Dynamo's logging configuration at import time (see dynamo.runtime.logging).
configure_dynamo_logging()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
66
# Sleep-mode level handed to checkpoint_cfg.run_lifecycle() in worker().
CHECKPOINT_SLEEP_MODE_LEVEL = 1
Alec's avatar
Alec committed
67
68


69
70
71
72
73
74
75
76
77
78
79
80
81
async def _handle_non_leader_node(dp_rank: int) -> None:
    """Park a non-leader node (data_parallel_rank >= 1) of a multi-node deployment.

    Non-leader nodes run vLLM workers but do not serve Dynamo endpoints, so
    this coroutine only logs the situation and then blocks forever.

    Args:
        dp_rank: Data-parallel rank of this node; used only in the log message.
    """
    message = (
        f"Non-leader node detected (data_parallel_rank={dp_rank}). "
        "Skipping endpoint serving."
    )
    logger.info(message)

    # This event is never set: the process is expected to be terminated
    # through its signal handlers.
    never_set = asyncio.Event()
    await never_set.wait()


82
async def graceful_shutdown(runtime, shutdown_event):
    """
    Shutdown dynamo distributed runtime.

    The endpoints will be immediately invalidated so no new requests will be accepted.
    For endpoints served with graceful_shutdown=True, the serving function will wait
    until all in-flight requests are finished.
    For endpoints served with graceful_shutdown=False, the serving function will
    return immediately.

    Args:
        runtime: Runtime object whose ``shutdown()`` method is invoked.
        shutdown_event: Event that is set before the runtime is shut down so
            that other tasks can observe the shutdown.
    """
    # Log through the module logger (not the root logger) so these records carry
    # this module's name like every other message emitted from this file.
    logger.info("Received shutdown signal, shutting down DistributedRuntime")
    shutdown_event.set()
    runtime.shutdown()
    logger.info("DistributedRuntime shutdown complete")
93
94


95
async def worker():
    """Run one vLLM worker process from argument parsing to endpoint serving.

    Steps, in order: parse and dump the configuration, default the served
    model name, consult the checkpoint configuration, make sure the model is
    on disk, optionally create the engine up front and run the checkpoint
    lifecycle, create the distributed runtime, and finally dispatch to the
    init routine selected by the configuration flags.
    """
    config = parse_args()
    dump_config(config.dump_config_to, config)

    # Without an explicit served model name, fall back to the model identifier:
    # either the full path (vllm and sglang do the same) or the HF name
    # (e.g. "Qwen/Qwen3-0.6B"), depending on the command-line parameters.
    if not config.served_model_name:
        fallback_name = config.model
        config.served_model_name = fallback_name
        config.engine_args.served_model_name = fallback_name

    # Look at checkpoint mode EARLY so a misconfiguration fails fast; return
    # immediately when a checkpoint already exists.
    checkpoint_cfg = get_checkpoint_config()
    if checkpoint_cfg and checkpoint_cfg.checkpoint_exists():
        return

    # Fetch the model via modelexpress when it is not a local path, so it is on
    # disk before vllm starts and is not downloaded from HuggingFace.
    #
    # `config.engine_args.model` is deliberately NOT replaced with the local
    # path returned by fetch_model: vllm forwards that name to its Ray
    # pipeline-parallel workers, which may not have the local path. vllm will
    # try to download again but find the model in the HF cache. For non-HF
    # models pass a path instead of an HF name and make sure every worker can
    # see that path (ideally via a shared folder).
    model_ref = config.model
    if not os.path.exists(model_ref):
        await fetch_model(model_ref)

    # CHECKPOINT MODE: the engine is loaded BEFORE the runtime is created so
    # GPU state can be checkpointed before runtime connections exist.
    pre_created_engine = None
    if checkpoint_cfg is not None:
        logger.info(
            f"Checkpoint mode enabled (signal_file={checkpoint_cfg.signal_file})"
        )

        # Sleep mode is a prerequisite of checkpoint mode; switch it on before
        # the engine is initialised.
        config.engine_args.enable_sleep_mode = True

        pre_created_engine = setup_vllm_engine(config)
        checkpoint_engine = pre_created_engine[0]

        lifecycle_ok = await checkpoint_cfg.run_lifecycle(
            checkpoint_engine, CHECKPOINT_SLEEP_MODE_LEVEL
        )
        if not lifecycle_ok:
            return

    shutdown_event = asyncio.Event()
    runtime, _unused = create_runtime(
        discovery_backend=config.discovery_backend,
        request_plane=config.request_plane,
        event_plane=config.event_plane,
        use_kv_events=config.use_kv_events,
        shutdown_event=shutdown_event,
    )

    # Pick the init routine from the config flags; the first matching flag wins.
    if config.multimodal_encode_worker:
        await init_multimodal_encode_worker(runtime, config, shutdown_event)
        logger.debug("init_multimodal_encode_worker completed")
    elif (
        config.multimodal_worker
        or config.multimodal_decode_worker
        or config.multimodal_encode_prefill_worker
    ):
        await init_multimodal_worker(
            runtime,
            config,
            shutdown_event,
            pre_created_engine=pre_created_engine,
        )
        logger.debug("init_multimodal_worker completed")
    elif config.omni:
        await init_omni(runtime, config, shutdown_event)
        logger.debug("init_omni completed")
    elif config.is_prefill_worker:
        await init_prefill(
            runtime,
            config,
            shutdown_event,
            pre_created_engine=pre_created_engine,
        )
        logger.debug("init_prefill completed")
    else:
        await init(
            runtime,
            config,
            shutdown_event,
            pre_created_engine=pre_created_engine,
        )
        logger.debug("init completed")

    logger.debug("Worker function completed, exiting...")
Alec's avatar
Alec committed
178
179


180
181
182
183
184
185
186
187
188
189
190
191
192
193
def setup_metrics_collection(config: Config, generate_endpoint, logger):
    """Register the engine-metrics callbacks for vLLM and LMCache metrics.

    Nothing is registered unless ``config.engine_args.disable_log_stats`` is
    exactly ``False``.

    In multiprocess mode (PROMETHEUS_MULTIPROC_DIR set), metrics live in two places:
      1. In-memory: metric objects in the global REGISTRY
      2. On-disk: metric values in .db files (PROMETHEUS_MULTIPROC_DIR)

    MultiProcessCollector reads the .db files, but adding it to REGISTRY can
    fail with "Duplicated timeseries" when PROMETHEUS_MULTIPROC_DIR was set
    before the process started (K8s deployments), because the metrics are
    already in REGISTRY. The function therefore first tries the global
    REGISTRY and, on ValueError, falls back to a separate registry for the
    multiprocess collector and registers callbacks on both registries so that
    vllm, lmcache and dynamo_component metrics are all collected.

    The hierarchy values (namespace, component, endpoint) and the model name
    from ``config`` are forwarded with every vllm/lmcache registration so the
    engine metrics can be labelled consistently with the Rust auto-labels.

    Args:
        config: Worker configuration supplying ``engine_args`` and the label values.
        generate_endpoint: Endpoint on which the callbacks are registered.
        logger: Logger used for the debug messages emitted here.
    """
    if config.engine_args.disable_log_stats is not False:
        return

    # Dedicated dynamo_component registry.
    # IMPORTANT: MultiProcessCollector is NOT used for DYNAMO_COMPONENT_REGISTRY:
    # its gauges hold in-memory values, which is fine for single- and
    # multi-process setups (each process has its own gauge with a dp_rank
    # label). Reading the .db files instead would let stale values accumulate
    # across test runs.
    register_engine_metrics_callback(
        endpoint=generate_endpoint,
        registry=DYNAMO_COMPONENT_REGISTRY,
    )

    def _register_labelled(registry, prefixes):
        # One labelled registration; only the registry and the prefix filter vary.
        register_engine_metrics_callback(
            endpoint=generate_endpoint,
            registry=registry,
            metric_prefix_filters=prefixes,
            namespace_name=config.namespace,
            component_name=config.component,
            endpoint_name=config.endpoint,
            model_name=config.model,
        )

    if not os.environ.get("PROMETHEUS_MULTIPROC_DIR"):
        # No multiprocess mode
        _register_labelled(REGISTRY, ["vllm:", "lmcache:"])
        return

    try:
        # With the collector attached, the global REGISTRY yields both the
        # in-memory metrics and the ones stored in the .db files.
        multiprocess.MultiProcessCollector(REGISTRY)
        logger.debug("Added MultiProcessCollector to global REGISTRY")
        _register_labelled(REGISTRY, ["vllm:", "lmcache:"])
    except ValueError as e:
        # Conflict: the metrics are already in REGISTRY and the collector tries
        # to add the same ones from the .db files. Use a separate registry that
        # ONLY reads the .db files, which avoids the in-memory clash.
        logger.debug(
            f"Could not add MultiProcessCollector to REGISTRY ({e}), using separate registry"
        )
        multiproc_registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(multiproc_registry)

        # Global REGISTRY carries the in-memory metrics (vllm) ...
        _register_labelled(REGISTRY, ["vllm:"])
        # ... and the multiproc registry the .db file metrics (lmcache,
        # possibly vllm duplicates).
        _register_labelled(multiproc_registry, ["vllm:", "lmcache:"])


Yan Ru Pei's avatar
Yan Ru Pei committed
270
271
272
273
274
def setup_kv_event_publisher(
    config: Config,
    component,
    generate_endpoint,
    vllm_config,
    consolidator_enabled: bool = False,
    consolidator_port: Optional[int] = 5558,
) -> Optional[list[KvEventPublisher]]:
    """
    Set up KV event publishers for prefix caching if enabled.
    Creates one publisher per dp_rank since each dp_rank publishes to a different port.

    Args:
        config: Worker configuration
        component: Component for runtime integration
        generate_endpoint: Endpoint for worker ID (not read by this function's
            body; kept so existing callers keep working)
        vllm_config: vLLM configuration
        consolidator_enabled: If True, subscribe to the kv event consolidator's ZMQ endpoint
        consolidator_port: Port where kv event consolidator publishes (default: 5558)

    Returns:
        List of KvEventPublisher instances (one per dp_rank) if prefix caching
        and KV cache events are enabled and this is not a decode worker,
        None otherwise.

    Raises:
        ValueError: If ``consolidator_enabled`` is True but ``consolidator_port``
            is None, which would otherwise produce an invalid ZMQ endpoint.
    """
    engine_args = config.engine_args
    if not engine_args.enable_prefix_caching:
        return None

    # Skip KV event publishing for decode workers
    if config.is_decode_worker:
        logger.info("Skipping KV event publisher setup for decode worker")
        return None

    kv_events_config = engine_args.kv_events_config
    if kv_events_config is None:
        return None

    # Check if kv_cache_events are explicitly disabled
    if not kv_events_config.enable_kv_cache_events:
        logger.info(
            "KV event publishing skipped: enable_kv_cache_events=False in kv_events_config"
        )
        return None

    # Get data_parallel_size to create publishers for all dp_ranks
    data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
    kv_publishers = []

    for dp_rank in range(data_parallel_size):
        if consolidator_enabled:
            if consolidator_port is None:
                # Fail loudly instead of subscribing to "tcp://127.0.0.1:None".
                raise ValueError(
                    "consolidator_port must be set when consolidator_enabled is True"
                )
            # TODO: Use different port for each dp_rank once KVBM supports DP
            zmq_endpoint = f"tcp://127.0.0.1:{consolidator_port}"
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to consolidator at {zmq_endpoint}"
            )
        else:
            # Each dp_rank publishes to a different port; the wildcard bind
            # address is replaced by loopback to obtain a connectable address.
            zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
                kv_events_config.endpoint,
                data_parallel_rank=dp_rank,
            ).replace("*", "127.0.0.1")
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to vLLM at {zmq_endpoint}"
            )

        kv_publisher = KvEventPublisher(
            component=component,
            kv_block_size=vllm_config.cache_config.block_size,
            zmq_endpoint=zmq_endpoint,
            zmq_topic="",
            enable_local_indexer=config.enable_local_indexer,
            dp_rank=dp_rank,
        )
        kv_publishers.append(kv_publisher)

        logger.info(
            f"Worker reading KV events for dp_rank={dp_rank} from {zmq_endpoint}"
        )

    # An empty list (data_parallel_size == 0) is reported as "no publishers".
    return kv_publishers if kv_publishers else None
Yan Ru Pei's avatar
Yan Ru Pei committed
346
347


Alec's avatar
Alec committed
348
def setup_vllm_engine(config, stat_logger=None):
    """Create the vLLM ``AsyncLLM`` engine client and its supporting objects.

    Prepares the Prometheus multiprocess directory, builds the component
    gauges, sets vLLM-related environment variables, selects a custom worker
    class for special load formats, builds the engine configuration and
    finally instantiates the engine while timing its creation.

    Args:
        config: Worker configuration. ``engine_args``, ``served_model_name``,
            ``component``, ``model_express_url`` and ``has_connector()`` are
            read; ``engine_args.worker_cls`` may be overwritten.
        stat_logger: Optional stat-logger factory. When given, it receives the
            created gauges through its ``component_gauges`` attribute and is
            passed to the engine as its only stat logger.

    Returns:
        Tuple ``(engine_client, vllm_config, default_sampling_params,
        prometheus_temp_dir, component_gauges)``. ``prometheus_temp_dir`` is
        the ``TemporaryDirectory`` created here, or ``None`` when
        ``PROMETHEUS_MULTIPROC_DIR`` was already set in the environment.

    Raises:
        ImportError: If a ModelExpress load format is requested but the
            ``modelexpress`` package cannot be imported.
    """
    # vLLM v0.11.0 bug: vllm/v1.metrics/prometheus.py:79 passes TemporaryDirectory object
    # instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR
    # ourselves to avoid this and handle cleanup properly.
    prometheus_temp_dir = None
    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
        prometheus_temp_dir = tempfile.TemporaryDirectory(prefix="vllm_prometheus_")
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_temp_dir.name
        logger.debug(
            f"Created PROMETHEUS_MULTIPROC_DIR at: {os.environ['PROMETHEUS_MULTIPROC_DIR']}"
        )

    setup_multiprocess_prometheus()  # call vLLM's library's function to setup multiprocess prometheus
    logger.debug(
        f"Prometheus multiproc dir set to: {os.environ.get('PROMETHEUS_MULTIPROC_DIR')}"
    )

    # Construct Prometheus gauges AFTER setup_multiprocess_prometheus() so Gauge objects
    # see the correct ValueClass (multiprocess vs in-memory).
    component_gauges = LLMBackendMetrics(
        registry=DYNAMO_COMPONENT_REGISTRY,
        model_name=config.served_model_name or "",
        component_name=config.component or "",
    )

    # If a StatLoggerFactory was provided, give it the gauges so the loggers
    # it creates can publish Prometheus metrics.
    if stat_logger is not None:
        stat_logger.component_gauges = component_gauges

    os.environ["VLLM_NO_USAGE_STATS"] = "1"  # Avoid internal HTTP requests
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    engine_args = config.engine_args

    if engine_args.enable_lora:
        # Only provide defaults; values already set by the user win.
        if "VLLM_ALLOW_RUNTIME_LORA_UPDATING" not in os.environ:
            os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
        if "VLLM_LORA_MODULES_LOADING_TIMEOUT" not in os.environ:
            os.environ["VLLM_LORA_MODULES_LOADING_TIMEOUT"] = "600"

    if engine_args.load_format == "gms":
        engine_args.worker_cls = "gpu_memory_service.integrations.vllm.worker.GMSWorker"

    if engine_args.load_format in ("mx-source", "mx-target"):
        try:
            from modelexpress import register_modelexpress_loaders

            # Ensure the ModelExpress server URL env var is set for the model loader
            if config.model_express_url:
                os.environ["MODEL_EXPRESS_URL"] = config.model_express_url
            register_modelexpress_loaders()
            # Use wrapper worker to ensure loaders are registered in spawned worker processes
            engine_args.worker_cls = "modelexpress.vllm_worker.ModelExpressWorker"
        except ImportError as e:
            raise ImportError(
                f"ModelExpress package required for --load-format={engine_args.load_format}. "
                "Install with: pip install modelexpress"
            ) from e

    # Load default sampling params from `generation_config.json`
    default_sampling_params = (
        engine_args.create_model_config().get_diff_sampling_param()
    )

    # Taken from build_async_engine_client_from_engine_args()
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    # Set up consolidator endpoints if KVBM is enabled
    consolidator_endpoints = None
    if config.has_connector("kvbm"):
        # Deliberately best-effort: any failure here only disables KV event
        # consolidation and is reported as a warning.
        try:
            from kvbm.vllm_integration.consolidator_config import (
                get_consolidator_endpoints,
            )

            consolidator_endpoints = get_consolidator_endpoints(vllm_config)
        except Exception as e:
            logger.warning(
                f"KVBM connector is enabled but failed to get consolidator endpoints: {e}. "
                "Continuing without KV event consolidation. "
                "Ensure 'kvbm' package is installed if this feature is needed."
            )
    vllm_config.consolidator_endpoints = consolidator_endpoints

    # Same "is not None" test as above, so a factory that received the gauges
    # is always the one handed to the engine.
    factory = [stat_logger] if stat_logger is not None else []

    # Time engine initialization with a monotonic clock: unlike time.time(),
    # perf_counter() is not affected by system clock adjustments.
    start_time = time.perf_counter()
    engine_client = AsyncLLM.from_vllm_config(
        vllm_config=vllm_config,
        usage_context=usage_context,
        stat_loggers=factory,
        enable_log_requests=engine_args.enable_log_requests,
        disable_log_stats=engine_args.disable_log_stats,
    )
    load_time = time.perf_counter() - start_time

    # Record model load time
    component_gauges.set_model_load_time(load_time)

    logger.info(f"VllmWorker for {config.served_model_name} has been initialized")

    return (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
        component_gauges,
    )
Alec's avatar
Alec committed
461
462


463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
async def register_vllm_model(
    model_input: ModelInput,
    model_type: ModelType,
    generate_endpoint,
    config: Config,
    engine_client: AsyncLLM,
    vllm_config,
):
    """
    Helper function to register a vLLM model with runtime configuration.

    Args:
        model_input: Input type for the model (e.g., ModelInput.Tokens)
        model_type: Type of model (e.g., ModelType.Chat, ModelType.Prefill)
        generate_endpoint: Endpoint to register
        config: Configuration object
        engine_client: vLLM engine client
        vllm_config: vLLM configuration

    Raises:
        RuntimeError: If ``config.frontend_decoding`` is set but MediaDecoder /
            MediaFetcher could not be imported from ``dynamo.llm``.
    """
    runtime_config = ModelRuntimeConfig()

    # Get runtime configuration from vLLM engine.
    # Logged through the module logger (not the root logger) so the record
    # carries this module's name like the rest of the file's messages.
    logger.info(
        f"Getting engine runtime configuration metadata from vLLM engine for {model_type}..."
    )
    runtime_values = get_engine_cache_info(engine_client)
    runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
    runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
    runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
    # Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer
    runtime_config.enable_local_indexer = (
        config.enable_local_indexer and not config.is_decode_worker
    )

    # Add tool/reasoning parsers for decode models
    if model_type != ModelType.Prefill:
        runtime_config.tool_call_parser = config.dyn_tool_call_parser
        runtime_config.reasoning_parser = config.dyn_reasoning_parser

    # Get data_parallel_size from vllm_config (defaults to 1)
    data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
    runtime_config.data_parallel_size = data_parallel_size

    # Configure media decoder for frontend image decoding when enabled
    # This enables frontend to decode images and transfer via NIXL RDMA
    media_decoder = None
    media_fetcher = None
    if config.frontend_decoding:
        if not MEDIA_DECODER_AVAILABLE:
            raise RuntimeError(
                "--frontend-decoding requires MediaDecoder support. "
                "Ensure dynamo.llm module includes MediaDecoder and MediaFetcher."
            )
        media_decoder = MediaDecoder()
        # max_alloc limit of 128 MiB for image decoding.
        media_decoder.enable_image({"limits": {"max_alloc": 128 * 1024 * 1024}})
        # media_decoder.enable_video({})

        media_fetcher = MediaFetcher()
        media_fetcher.timeout_ms(30000)
        media_fetcher.allow_direct_port(True)

    await register_model(
        model_input,
        model_type,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=runtime_values["block_size"],
        runtime_config=runtime_config,
        custom_template_path=config.custom_jinja_template,
        media_decoder=media_decoder,
        media_fetcher=media_fetcher,
    )


538
async def init_prefill(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    pre_created_engine=None,
):
    """
    Instantiate and serve a prefill worker.

    Creates (or reuses) the vLLM engine, builds the prefill handler, wires up
    KV event publishers and metrics collection, registers the sleep/wake_up
    engine routes, registers the model as ``ModelType.Prefill`` and then serves
    the generate and ``clear_kv_blocks`` endpoints until they complete.

    Args:
        runtime: Distributed runtime used to resolve the component and to
            register engine routes.
        config: Worker configuration.
        shutdown_event: Event handed to the handler.
        pre_created_engine: Optional tuple as returned by ``setup_vllm_engine``
            (checkpoint mode). When None, a new engine is created.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")

    # In practice config.served_model_name is always set, but mypy needs the
    # "or" here. Both endpoints use the same fallback so their metrics carry
    # identical model labels.
    model_name = config.served_model_name or config.model
    model_metrics_labels = [
        (prometheus_names.labels.MODEL, model_name),
        (prometheus_names.labels.MODEL_NAME, model_name),
    ]

    # Use pre-created engine if provided (checkpoint mode), otherwise create new
    engine_bundle = (
        pre_created_engine
        if pre_created_engine is not None
        else setup_vllm_engine(config)
    )
    (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
        _component_gauges,
    ) = engine_bundle

    handler = PrefillWorkerHandler(
        runtime,
        component,
        engine_client,
        default_sampling_params,
        getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
        shutdown_event=shutdown_event,
        enable_frontend_decoding=config.frontend_decoding,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
    consolidator_enabled = False
    consolidator_port = None

    consolidator_endpoints = getattr(vllm_config, "consolidator_endpoints", None)
    if consolidator_endpoints:
        # Extract connect endpoint (third element) for clients to subscribe
        # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint)
        consolidator_output_endpoint = consolidator_endpoints[2]
        consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
        consolidator_enabled = True

    # Set up KV event publishers for prefix caching if enabled (one per dp_rank)
    # If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
    kv_publishers = setup_kv_event_publisher(
        config,
        component,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    setup_metrics_collection(config, generate_endpoint, logger)

    # Register sleep/wake_up engine routes
    runtime.register_engine_route("sleep", handler.sleep)
    runtime.register_engine_route("wake_up", handler.wake_up)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # Handle non-leader nodes - don't serve endpoints
    if config.engine_args.data_parallel_rank:
        await _handle_non_leader_node(config.engine_args.data_parallel_rank)
        return

    # Register prefill model with ModelType.Prefill
    model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
    await register_vllm_model(
        model_input,
        ModelType.Prefill,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
    )

    health_check_payload = VllmPrefillHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    ).to_dict()

    try:
        logger.debug("Starting serve_endpoint for prefill worker")
        await asyncio.gather(
            # for prefill, we want to shutdown the engine after all prefill requests are finished because
            #     (temp reason): we don't support re-routing prefill requests
            #     (long-term reason): prefill engine should pull from a global queue so there is
            #                         only a few in-flight requests that can be quickly finished
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=list(model_metrics_labels),
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=list(model_metrics_labels),
            ),
        )
        logger.debug("serve_endpoint completed for prefill worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up prefill worker")
        handler.cleanup()


679
async def init(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    pre_created_engine=None,
):
    """
    Instantiate the vLLM engine and decode handler, then serve the endpoints.

    Steps, in order: create the component endpoints, obtain the engine (reusing
    ``pre_created_engine`` when given), build the ``DecodeWorkerHandler``, wire
    up KV event publishing and metrics collection, register the sleep/wake_up
    engine routes, register the model, and finally serve the generate,
    clear_kv_blocks and LoRA endpoints concurrently until they complete.

    Args:
        runtime: Distributed runtime used to resolve the namespace/component
            and to register engine routes.
        config: Worker configuration.
        shutdown_event: Event forwarded to the handler.
        pre_created_engine: Optional 5-tuple ``(engine_client, vllm_config,
            default_sampling_params, prometheus_temp_dir, component_gauges)``
            (checkpoint mode). When ``None`` the engine is created via
            ``setup_vllm_engine``.

    Raises:
        Exception: Any error raised while serving the endpoints is logged and
            re-raised; ``handler.cleanup()`` runs in either case.
    """

    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")
    load_lora_endpoint = component.endpoint("load_lora")
    unload_lora_endpoint = component.endpoint("unload_lora")
    list_loras_endpoint = component.endpoint("list_loras")

    model_name = config.served_model_name or config.model

    # Use pre-created engine if provided (checkpoint mode), otherwise create new
    if pre_created_engine is not None:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            component_gauges,
        ) = pre_created_engine
        # Factory is created after unpack so component_gauges is available
        factory = StatLoggerFactory(
            component,
            component_gauges=component_gauges,
            dp_rank=config.engine_args.data_parallel_rank or 0,
            metrics_labels=[("model", model_name)],
        )
    else:
        # Factory is created without component_gauges; setup_vllm_engine() will
        # create the gauges after setup_multiprocess_prometheus() and set them
        # on the factory before vLLM calls create_stat_logger().
        factory = StatLoggerFactory(
            component,
            dp_rank=config.engine_args.data_parallel_rank or 0,
            metrics_labels=[("model", model_name)],
        )
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            _component_gauges,  # not read again in this function
        ) = setup_vllm_engine(config, factory)

    # TODO Hack to get data, move this to registering in TBD
    factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks)
    factory.init_publish()

    handler = DecodeWorkerHandler(
        runtime,
        component,
        engine_client,
        default_sampling_params,
        getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
        shutdown_event=shutdown_event,
        enable_frontend_decoding=config.frontend_decoding,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
    consolidator_enabled = False
    consolidator_port = None

    consolidator_endpoints = getattr(vllm_config, "consolidator_endpoints", None)
    if consolidator_endpoints:
        # Extract connect endpoint (third element) for clients to subscribe
        # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint)
        consolidator_output_endpoint = consolidator_endpoints[2]
        consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
        consolidator_enabled = True

    # Set up KV event publisher for prefix caching if enabled
    # If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
    kv_publishers = setup_kv_event_publisher(
        config,
        component,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    setup_metrics_collection(config, generate_endpoint, logger)

    # Register sleep/wake_up engine routes
    runtime.register_engine_route("sleep", handler.sleep)
    runtime.register_engine_route("wake_up", handler.wake_up)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # Handle non-leader nodes - don't serve endpoints
    if config.engine_args.data_parallel_rank:
        await _handle_non_leader_node(config.engine_args.data_parallel_rank)
        return

    # Parse endpoint types from --endpoint-types flag
    model_type = parse_endpoint_types(config.endpoint_types)
    logger.info(f"Registering model with endpoint types: {config.endpoint_types}")

    model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens

    # Warn if custom template provided but chat endpoint not enabled
    if config.custom_jinja_template and "chat" not in config.endpoint_types:
        logger.warning(
            "Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
            "The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
        )

    await register_vllm_model(
        model_input,
        model_type,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
    )

    health_check_payload = VllmHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    ).to_dict()

    # Every endpoint served below carries the same model labels; build them
    # once from the already-resolved model_name instead of repeating the list.
    endpoint_labels = [
        (prometheus_names.labels.MODEL, model_name),
        (prometheus_names.labels.MODEL_NAME, model_name),
    ]

    try:
        logger.debug("Starting serve_endpoint for decode worker")
        await asyncio.gather(
            # for decode, we want to transfer the in-flight requests to other decode engines,
            # because waiting for them to finish can take a long time for long OSLs
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=endpoint_labels,
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=endpoint_labels,
            ),
            load_lora_endpoint.serve_endpoint(
                handler.load_lora,
                metrics_labels=endpoint_labels,
            ),
            unload_lora_endpoint.serve_endpoint(
                handler.unload_lora,
                metrics_labels=endpoint_labels,
            ),
            list_loras_endpoint.serve_endpoint(
                handler.list_loras,
                metrics_labels=endpoint_labels,
            ),
        )
        logger.debug("serve_endpoint completed for decode worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up decode worker")
        # Cleanup background tasks
        handler.cleanup()


899
900
901
902
903
904
905
def get_engine_cache_info(engine: AsyncLLM):
    """Retrieve cache configuration information from [`AsyncLLM`] engine.

    Reads the values straight from ``engine.vllm_config`` and logs them.

    Args:
        engine: Engine whose ``vllm_config`` exposes ``cache_config`` and
            ``scheduler_config``.

    Returns:
        A dict with the keys ``num_gpu_blocks``, ``block_size``,
        ``max_num_seqs`` and ``max_num_batched_tokens``.

    Raises:
        Exception: Whatever error occurs while reading the config attributes;
            it is logged and re-raised unchanged.
    """
    try:
        # Get values directly from vllm_config instead of collective_rpc
        cache_config = engine.vllm_config.cache_config
        scheduler_config = engine.vllm_config.scheduler_config

        cache_values = {
            "num_gpu_blocks": cache_config.num_gpu_blocks,
            "block_size": cache_config.block_size,
        }
        scheduler_values = {
            "max_num_seqs": scheduler_config.max_num_seqs,
            "max_num_batched_tokens": scheduler_config.max_num_batched_tokens,
        }
    except Exception as e:
        logger.error(f"Failed to get configuration values from vLLM config: {e}")
        raise

    # Use the module logger, like the rest of this file, rather than the
    # root-level logging functions.
    logger.info(f"Cache config values: {cache_values}")
    logger.info(f"Scheduler config values: {scheduler_values}")

    # The two dicts have disjoint keys, so merging them yields the same
    # four-key result (in the same order) as spelling every key out again.
    return {**cache_values, **scheduler_values}


927
928
929
async def init_multimodal_encode_worker(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """Initialize the multimodal encode worker and serve its generate endpoint.

    Builds an ``EncodeWorkerHandler`` from ``config.engine_args``, runs its
    async initialisation against ``runtime`` and serves ``handler.generate``
    on the configured endpoint. ``handler.cleanup()`` always runs on exit.

    Args:
        runtime: Distributed runtime used to resolve the component and to
            initialise the handler.
        config: Worker configuration.
        shutdown_event: Accepted for signature parity with the other init
            functions; not read in this function.

    Raises:
        Exception: Errors raised while serving are logged and re-raised.
    """
    namespace = runtime.namespace(config.namespace)
    component = namespace.component(config.component)
    generate_endpoint = component.endpoint(config.endpoint)

    handler = EncodeWorkerHandler(config.engine_args)
    await handler.async_init(runtime)
    logger.info("Starting to serve the encode worker endpoint...")

    # Both label keys carry the model identifier taken from config.model.
    encode_labels = [
        (prometheus_names.labels.MODEL, config.model),
        (prometheus_names.labels.MODEL_NAME, config.model),
    ]

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                metrics_labels=encode_labels,
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve encode worker endpoint: {e}")
        raise
    finally:
        handler.cleanup()


958
async def init_multimodal_worker(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    pre_created_engine=None,
):
    """
    Initialize multimodal worker component.

    Supports three modes:
    1. --multimodal-worker: Prefill+decode worker for multimodal LLM; can route
       to a separate encoder (--route-to-encoder) for embeddings. Runs
       aggregated (P+D) or disaggregated (P→D).
    2. --multimodal-decode-worker: Decode-only worker in disaggregated (P→D)
       mode.
    3. --multimodal-encode-prefill-worker: Unified encode+prefill+decode in one
       worker for models with integrated image encoding (e.g., Llama 4).

    Args:
        runtime: Distributed runtime used to resolve components and clients.
        config: Worker configuration.
        shutdown_event: Event forwarded to the selected handler.
        pre_created_engine: Optional 5-tuple ``(engine_client, vllm_config,
            default_sampling_params, prometheus_temp_dir, component_gauges)``
            (checkpoint mode). When ``None`` the engine is created via
            ``setup_vllm_engine``.

    Raises:
        Exception: Errors raised while serving the endpoints are logged and
            re-raised; ``handler.cleanup()`` runs in either case.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")

    # Use pre-created engine if provided (checkpoint mode), otherwise create new
    if pre_created_engine is not None:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            _component_gauges,
        ) = pre_created_engine
    else:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            _component_gauges,
        ) = setup_vllm_engine(config)

    # Set up encode worker client when routing to encoder is enabled
    # (PD worker handles encode routing directly instead of a separate processor)
    encode_worker_client = None
    if config.route_to_encoder:
        encode_worker_client = (
            await runtime.namespace(config.namespace)
            .component("encoder")
            .endpoint("generate")
            .client()
        )
        logger.info("Waiting for Encoder Worker Instances ...")
        await encode_worker_client.wait_for_instances()
        logger.info("Connected to encoder workers")

    # Set up decode worker client for disaggregated mode
    decode_worker_client = None
    if config.is_prefill_worker:
        # Prefill worker needs to connect to decode worker
        decode_worker_client = (
            await runtime.namespace(config.namespace)
            .component("decoder")
            .endpoint("generate")
            .client()
        )
        await decode_worker_client.wait_for_instances()
        logger.info("Connected to decode worker for disaggregated mode")

    # Choose handler based on worker type
    if config.multimodal_decode_worker:
        handler = MultimodalDecodeWorkerHandler(
            runtime, component, engine_client, config, shutdown_event
        )
    else:
        handler = MultimodalPDWorkerHandler(
            runtime,
            component,
            engine_client,
            config,
            encode_worker_client,
            decode_worker_client,
            shutdown_event,
        )
    handler.add_temp_dir(prometheus_temp_dir)

    await handler.async_init(runtime)

    # Set up KV event publisher for prefix caching if enabled
    kv_publishers = setup_kv_event_publisher(
        config, component, generate_endpoint, vllm_config
    )
    if kv_publishers:
        # init() stores the result of this same helper on
        # `handler.kv_publishers`; use the same attribute name here so both
        # worker types expose the publishers identically.
        handler.kv_publishers = kv_publishers
        # Keep the previous singular attribute as an alias for any existing
        # reader of the old name.
        handler.kv_publisher = kv_publishers

    # Register model with the frontend so it can route requests
    model_type = parse_endpoint_types(config.endpoint_types)
    model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
    await register_vllm_model(
        model_input,
        model_type,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
    )

    metrics_labels = [
        (prometheus_names.labels.MODEL, config.served_model_name or config.model),
        (prometheus_names.labels.MODEL_NAME, config.served_model_name or config.model),
    ]

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                metrics_labels=metrics_labels,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=metrics_labels,
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()


1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
async def init_omni(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """
    Initialize and serve an Omni worker backed by the vLLM-Omni orchestrator.

    Creates an ``OmniHandler``, sets up metrics collection, registers the model
    with text input and a model type derived from ``config.output_modalities``
    (falling back to ``ModelType.Images`` when none is resolved), then serves
    ``handler.generate`` on the configured endpoint.

    Args:
        runtime: Distributed runtime used to resolve the component.
        config: Worker configuration.
        shutdown_event: Event forwarded to the handler.

    Raises:
        Exception: Errors raised while serving are logged and re-raised;
            ``handler.cleanup()`` runs in either case.
    """
    # Lazy import to avoid loading vllm-omni unless explicitly needed
    from dynamo.vllm.omni import OmniHandler

    component = runtime.namespace(config.namespace).component(config.component)
    generate_endpoint = component.endpoint(config.endpoint)

    # Initialize OmniHandler with Omni orchestrator
    handler = OmniHandler(
        runtime=runtime,
        component=component,
        config=config,
        default_sampling_params={},
        shutdown_event=shutdown_event,
    )

    logger.info(f"Omni worker initialized for model: {config.model}")

    # Set up metrics collection for vLLM and LMCache metrics
    setup_metrics_collection(config, generate_endpoint, logger)

    # Handle non-leader nodes - don't serve endpoints
    dp_rank = config.engine_args.data_parallel_rank
    if dp_rank:
        await _handle_non_leader_node(dp_rank)
        return

    # TODO: extend for multi-stage pipelines
    resolved_type = get_output_modalities(config.output_modalities, config.model)
    # Default to Images
    model_type = ModelType.Images if resolved_type is None else resolved_type

    await register_model(
        ModelInput.Text,
        model_type,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info("Starting to serve Omni worker endpoint...")

    payload_builder = await VllmOmniHealthCheckPayload.create(handler.engine_client)
    health_check_payload = payload_builder.to_dict()

    served_name = config.served_model_name or config.model
    omni_labels = [
        (prometheus_names.labels.MODEL, served_name),
        (prometheus_names.labels.MODEL_NAME, served_name),
    ]

    try:
        await generate_endpoint.serve_endpoint(
            handler.generate,
            graceful_shutdown=True,
            metrics_labels=omni_labels,
            health_check_payload=health_check_payload,
        )
    except Exception as e:
        logger.error(f"Failed to serve Omni endpoint: {e}")
        raise
    finally:
        logger.debug("Cleaning up Omni worker")
        handler.cleanup()


Alec's avatar
Alec committed
1164
1165
1166
1167
def main():
    """Entry point: run the ``worker()`` coroutine via ``uvloop.run``."""
    uvloop.run(worker())


if __name__ == "__main__":
    main()