"docs/backends/sglang/multinode-examples.md" did not exist on "ef59ac8d3704968723df9ebc8985d7e0d0c44bcc"
main.py 48.4 KB
Newer Older
1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Alec's avatar
Alec committed
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0

import asyncio
import logging
import os
7
import tempfile
8
import time
9
from typing import Optional
Alec's avatar
Alec committed
10
11

import uvloop
12
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
Alec's avatar
Alec committed
13
14
15
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
16
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
Alec's avatar
Alec committed
17

18
from dynamo.common.config_dump import dump_config
19
from dynamo.common.utils.endpoint_types import parse_endpoint_types
20
21
22
23
from dynamo.common.utils.prometheus import (
    LLMBackendMetrics,
    register_engine_metrics_callback,
)
24
from dynamo.common.utils.runtime import create_runtime
Alec's avatar
Alec committed
25
from dynamo.llm import (
26
    KvEventPublisher,
27
    ModelInput,
28
    ModelRuntimeConfig,
Alec's avatar
Alec committed
29
30
    ModelType,
    ZmqKvEventPublisherConfig,
31
    fetch_llm,
Alec's avatar
Alec committed
32
33
    register_llm,
)
34
35
36
37
38
39
40
41
42
43
44

# Optional imports for frontend decoding support.
# If dynamo.llm does not export MediaDecoder/MediaFetcher, both names are bound
# to None and MEDIA_DECODER_AVAILABLE is set to False, so register_vllm_model()
# can raise a clear error when frontend decoding is requested.
try:
    from dynamo.llm import MediaDecoder, MediaFetcher

    MEDIA_DECODER_AVAILABLE = True
except ImportError:
    MediaDecoder = None
    MediaFetcher = None
    MEDIA_DECODER_AVAILABLE = False

45
from dynamo.runtime import DistributedRuntime
Alec's avatar
Alec committed
46
from dynamo.runtime.logging import configure_dynamo_logging
47
from dynamo.vllm.multimodal_handlers import (
48
    ECProcessorHandler,
49
    EncodeWorkerHandler,
Ayush Agarwal's avatar
Ayush Agarwal committed
50
    MultimodalDecodeWorkerHandler,
51
    MultimodalPDWorkerHandler,
GuanLuo's avatar
GuanLuo committed
52
    PreprocessedHandler,
53
    VLLMEncodeWorkerHandler,
54
)
55
from dynamo.vllm.multimodal_utils.encode_utils import create_ec_transfer_config
Alec's avatar
Alec committed
56

57
from .args import Config, overwrite_args, parse_args
Alec's avatar
Alec committed
58
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
59
60
61
62
63
from .health_check import (
    VllmHealthCheckPayload,
    VllmOmniHealthCheckPayload,
    VllmPrefillHealthCheckPayload,
)
64
from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory
Alec's avatar
Alec committed
65

Alec's avatar
Alec committed
66
67
68
69
# Configure Dynamo logging at import time, then create the module-level logger
# that the functions in this file log through.
configure_dynamo_logging()
logger = logging.getLogger(__name__)


70
71
72
73
74
75
76
77
78
79
80
81
82
async def _handle_non_leader_node(dp_rank: int) -> None:
    """
    Park a non-leader node (data_parallel_rank >= 1) of a multi-node deployment.

    Such a node runs vLLM workers but does not serve Dynamo endpoints, so this
    coroutine only logs the detected rank and then blocks on an event that is
    never set.

    Args:
        dp_rank: The node's data-parallel rank; used only in the log message.
    """
    notice = (
        f"Non-leader node detected (data_parallel_rank={dp_rank}). "
        "Skipping endpoint serving."
    )
    logger.info(notice)

    # Nothing sets this event, so the await never completes by itself; the
    # process is expected to be terminated via signal handlers.
    never_set = asyncio.Event()
    await never_set.wait()


83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
async def await_checkpoint_and_was_restored(signal_file: str) -> bool:
    """
    Poll once per second until a restore marker or the checkpoint signal appears.

    Two files are watched, in this priority order on every iteration:
    1. The restore marker file (path from DYN_RESTORE_MARKER_FILE, defaulting to
       /tmp/dynamo-restored). Per the original notes it is created by the
       restore entrypoint before CRIU restore, because os.environ is restored
       from the checkpoint and therefore cannot carry new container env vars.
    2. The checkpoint signal file, whose presence means the checkpoint is done.

    Args:
        signal_file: Path to the checkpoint signal file

    Returns:
        True if the restore marker was seen (caller should proceed with registration)
        False if the signal file was seen (caller should exit)
    """
    restore_marker = os.environ.get("DYN_RESTORE_MARKER_FILE", "/tmp/dynamo-restored")

    logger.info(
        f"CHECKPOINT_READY: Model loaded, ready for container checkpoint. Waiting for signal file: {signal_file} or restore marker file: {restore_marker}"
    )

    # The marker is tested first on each pass, so a restore wins over a signal
    # file that happens to exist at the same time.
    while not os.path.exists(restore_marker):
        if os.path.exists(signal_file):
            logger.info(f"Checkpoint signal file detected: {signal_file}")
            return False  # Checkpoint done - exit

        await asyncio.sleep(1)

    logger.info(
        f"Detected restore from checkpoint (marker file exists: {restore_marker})"
    )
    return True  # Restored - proceed with registration


125
async def worker():
    """Entry coroutine for a vLLM worker process.

    Steps, in order:
      1. Parse the arguments, apply overwrite_args(), and dump the config.
      2. Default the served model name to config.model when none was given.
      3. Checkpoint mode (DYN_CHECKPOINT_SIGNAL_FILE is set): return immediately
         if a finished checkpoint already exists on a PVC.
      4. Fetch the model when config.model is not an existing local path.
      5. Checkpoint mode: build the engine before the runtime exists, optionally
         put it to sleep, write the ready file, then wait. Return when the
         checkpoint signal arrives; continue when a restore is detected.
      6. Create the runtime and dispatch to the init_* coroutine selected by the
         config flags.
    """
    config = parse_args()
    overwrite_args(config)
    dump_config(config.dump_config_to, config)

    # Name the model. Use either the full path (vllm and sglang do the same),
    # or the HF name (e.g. "Qwen/Qwen3-0.6B"), depending on cmd line params.
    if not config.served_model_name:
        fallback_name = config.model
        config.served_model_name = fallback_name
        config.engine_args.served_model_name = fallback_name

    # Checkpoint-related environment variables are read before any heavy work.
    env = os.environ
    signal_file = env.get("DYN_CHECKPOINT_SIGNAL_FILE")
    ready_file = env.get("DYN_CHECKPOINT_READY_FILE")
    is_checkpoint_mode = signal_file is not None

    # Early exit: an existing checkpoint means there is nothing to do, and we
    # find that out before downloading the model.
    if is_checkpoint_mode:
        storage_type = env.get("DYN_CHECKPOINT_STORAGE_TYPE")
        checkpoint_location = env.get("DYN_CHECKPOINT_LOCATION")

        if storage_type == "pvc" and checkpoint_location:
            if os.path.exists(f"{checkpoint_location}/checkpoint.done"):
                logger.info(
                    f"Found existing checkpoint at {checkpoint_location}. Storage type: {storage_type}"
                )
                return
            logger.info(
                f"Checkpoint not found at: {checkpoint_location}. creating new checkpoint"
            )

    # Download the model if necessary using modelexpress, so it is on disk
    # before vllm starts and vllm does not download from HuggingFace itself.
    #
    # `config.engine_args.model` is deliberately NOT set to the local path that
    # fetch_llm returns: vllm forwards that name to its Ray pipeline-parallel
    # workers, which may not have the local path. vllm will attempt the download
    # again but find the model in the HF cache. For non-HF models pass a path
    # instead of an HF name and make sure every worker can see that path
    # (ideally via a shared folder).
    if not os.path.exists(config.model):
        await fetch_llm(config.model)

    # Checkpoint mode loads the engine BEFORE the runtime is created, so GPU
    # state can be checkpointed before runtime connections are established.
    pre_created_engine = None
    if is_checkpoint_mode:
        logger.info(
            f"Checkpoint mode enabled (DYN_CHECKPOINT_SIGNAL_FILE={signal_file})"
        )

        pre_created_engine = setup_vllm_engine(config)
        engine_client = pre_created_engine[0]

        # Put the model to sleep before the checkpoint (if sleep mode enabled).
        if config.engine_args.enable_sleep_mode:
            logger.info(f"Putting model to sleep (level={config.sleep_mode_level})")
            await engine_client.sleep(level=config.sleep_mode_level)

        # The ready file signals that this process can now be checkpointed.
        if ready_file:
            with open(ready_file, "w") as f:
                f.write("ready")
            logger.info(f"Wrote checkpoint ready file: {ready_file}")

        # Block until the checkpoint signal file OR a restore is detected.
        is_restored = await await_checkpoint_and_was_restored(signal_file)

        if not is_restored:
            # Checkpoint complete: this process has done its job.
            logger.info("Exiting after checkpoint completion")
            return

        # Restored: wake the model (if it was put to sleep) and carry on.
        if config.engine_args.enable_sleep_mode:
            logger.info("Waking up model after checkpoint restore")
            await engine_client.wake_up()
        logger.info("Proceeding with endpoint registration after restore")

    shutdown_event = asyncio.Event()
    runtime, _ = create_runtime(
        store_kv=config.store_kv,
        request_plane=config.request_plane,
        event_plane=config.event_plane,
        use_kv_events=config.use_kv_events,
        shutdown_event=shutdown_event,
    )

    # Route to the matching initializer; the first flag that is set wins.
    # Only the multimodal worker, prefill, and default paths receive the
    # pre-created engine.
    if config.vllm_native_encoder_worker:
        await init_vllm_native_encoder(runtime, config, shutdown_event)
        logger.debug("init_vllm_native_encoder completed")
    elif config.ec_processor:
        await init_ec_processor(runtime, config, shutdown_event)
        logger.debug("init_ec_processor completed")
    elif config.multimodal_processor:
        await init_multimodal_processor(runtime, config, shutdown_event)
        logger.debug("init_multimodal_processor completed")
    elif config.multimodal_encode_worker:
        await init_multimodal_encode_worker(runtime, config, shutdown_event)
        logger.debug("init_multimodal_encode_worker completed")
    elif (
        config.multimodal_worker
        or config.multimodal_decode_worker
        or config.multimodal_encode_prefill_worker
    ):
        await init_multimodal_worker(
            runtime, config, shutdown_event, pre_created_engine=pre_created_engine
        )
        logger.debug("init_multimodal_worker completed")
    elif config.omni:
        await init_omni(runtime, config, shutdown_event)
        logger.debug("init_omni completed")
    elif config.is_prefill_worker:
        await init_prefill(
            runtime, config, shutdown_event, pre_created_engine=pre_created_engine
        )
        logger.debug("init_prefill completed")
    else:
        await init(
            runtime, config, shutdown_event, pre_created_engine=pre_created_engine
        )
        logger.debug("init completed")

    logger.debug("Worker function completed, exiting...")
Alec's avatar
Alec committed
256
257


258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
def setup_metrics_collection(config: Config, generate_endpoint, logger):
    """Register metrics callbacks for vLLM, LMCache, and dynamo_component metrics.

    Nothing is registered unless config.engine_args.disable_log_stats is
    exactly False.

    Background: in multiprocess mode (PROMETHEUS_MULTIPROC_DIR set) metrics live
    in two places: Metric objects in the in-memory global REGISTRY, and metric
    values in .db files under PROMETHEUS_MULTIPROC_DIR. MultiProcessCollector
    reads the .db files, but attaching it to REGISTRY can fail with
    "Duplicated timeseries" when PROMETHEUS_MULTIPROC_DIR was set before the
    process started (K8s deployments), because the metrics are already in
    REGISTRY.

    Strategy: try to attach MultiProcessCollector to REGISTRY. On ValueError,
    fall back to a separate CollectorRegistry for the .db files and register
    callbacks for both registries so vllm, lmcache, and dynamo_component
    metrics are all collected.

    Args:
        config: Worker configuration; only engine_args.disable_log_stats is read.
        generate_endpoint: Endpoint the metrics callbacks are attached to.
        logger: Logger used for the debug messages emitted here.
    """
    if config.engine_args.disable_log_stats is not False:
        return

    def _add_callback(registry, **options):
        # Thin wrapper so every registration targets the same endpoint.
        register_engine_metrics_callback(
            endpoint=generate_endpoint,
            registry=registry,
            **options,
        )

    # Dedicated dynamo_component registry.
    # IMPORTANT: MultiProcessCollector is NOT used for DYNAMO_COMPONENT_REGISTRY:
    # its gauges hold in-memory values, which work for single-process and
    # multi-process alike (each process has its own gauge with a dp_rank label).
    # Reading them back from .db files would let stale values accumulate across
    # test runs.
    _add_callback(DYNAMO_COMPONENT_REGISTRY)

    if not os.environ.get("PROMETHEUS_MULTIPROC_DIR"):
        # No multiprocess mode
        _add_callback(REGISTRY, metric_prefix_filters=["vllm:", "lmcache:"])
        return

    try:
        # Attaching the collector to REGISTRY lets one registry expose both the
        # in-memory metrics and the .db-file metrics.
        multiprocess.MultiProcessCollector(REGISTRY)
        logger.debug("Added MultiProcessCollector to global REGISTRY")
        _add_callback(
            REGISTRY,
            metric_prefix_filters=[
                "vllm:",
                "lmcache:",
            ],
        )
    except ValueError as e:
        # Conflict: the metrics are already in REGISTRY and the collector tries
        # to add the same ones from the .db files. Use a separate registry that
        # ONLY reads the .db files, which has no in-memory conflicts.
        logger.debug(
            f"Could not add MultiProcessCollector to REGISTRY ({e}), using separate registry"
        )
        multiproc_registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(multiproc_registry)

        # Both registries are registered so that nothing is missed.
        # Global REGISTRY has in-memory metrics (vllm)
        _add_callback(REGISTRY, metric_prefix_filters=["vllm:"])
        # Multiproc registry has .db file metrics (lmcache, possibly vllm duplicates)
        _add_callback(
            multiproc_registry,
            metric_prefix_filters=[
                "vllm:",
                "lmcache:",
            ],
        )


Yan Ru Pei's avatar
Yan Ru Pei committed
333
334
335
336
337
def setup_kv_event_publisher(
    config: Config,
    component,
    generate_endpoint,
    vllm_config,
    consolidator_enabled: bool = False,
    consolidator_port: Optional[int] = 5558,
) -> Optional[list[KvEventPublisher]]:
    """
    Set up KV event publishers for prefix caching if enabled.
    Creates one publisher per dp_rank since each dp_rank publishes to a different port.

    No publishers are created (None is returned) when any of these holds:
    prefix caching is disabled, this is a decode worker, kv_events_config is
    unset, or kv_events_config.enable_kv_cache_events is falsy.

    Args:
        config: Worker configuration
        component: Component for runtime integration
        generate_endpoint: Endpoint for worker ID
        vllm_config: vLLM configuration
        consolidator_enabled: If True, subscribe to the KV event consolidator's ZMQ endpoint
        consolidator_port: Port where the KV event consolidator publishes (default: 5558)

    Returns:
        List of KvEventPublisher instances (one per dp_rank) if prefix caching is enabled, None otherwise.

    Raises:
        ValueError: If consolidator_enabled is True but consolidator_port is None,
            which would otherwise produce the endpoint "tcp://127.0.0.1:None".
    """
    if not config.engine_args.enable_prefix_caching:
        return None

    # Skip KV event publishing for decode workers
    if config.is_decode_worker:
        logger.info("Skipping KV event publisher setup for decode worker")
        return None

    if config.engine_args.kv_events_config is None:
        return None

    # Check if kv_cache_events are explicitly disabled
    if not config.engine_args.kv_events_config.enable_kv_cache_events:
        logger.info(
            "KV event publishing skipped: enable_kv_cache_events=False in kv_events_config"
        )
        return None

    # Fail fast instead of silently building a malformed ZMQ endpoint.
    if consolidator_enabled and consolidator_port is None:
        raise ValueError(
            "consolidator_port must be set when consolidator_enabled is True"
        )

    # Get data_parallel_size to create publishers for all dp_ranks
    data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
    kv_publishers = []

    for dp_rank in range(data_parallel_size):
        if consolidator_enabled:
            # TODO: Use different port for each dp_rank once KVBM supports DP
            zmq_endpoint = f"tcp://127.0.0.1:{consolidator_port}"
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to consolidator at {zmq_endpoint}"
            )
        else:
            # Each dp_rank publishes to a different port; a wildcard bind
            # address ("*") is rewritten to loopback so we can connect to it.
            zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
                config.engine_args.kv_events_config.endpoint,
                data_parallel_rank=dp_rank,
            ).replace("*", "127.0.0.1")
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to vLLM at {zmq_endpoint}"
            )

        zmq_config = ZmqKvEventPublisherConfig(
            worker_id=generate_endpoint.connection_id(),
            kv_block_size=vllm_config.cache_config.block_size,
            zmq_endpoint=zmq_endpoint,
            enable_local_indexer=config.enable_local_indexer,
            dp_rank=dp_rank,
        )
        kv_publisher = KvEventPublisher(component=component, zmq_config=zmq_config)
        kv_publishers.append(kv_publisher)

        logger.info(
            f"Worker reading KV events for dp_rank={dp_rank} from {zmq_endpoint}"
        )

    # An empty list (data_parallel_size of 0) is reported as None, like the
    # disabled cases above.
    return kv_publishers if kv_publishers else None
Yan Ru Pei's avatar
Yan Ru Pei committed
409
410


Alec's avatar
Alec committed
411
def setup_vllm_engine(config, stat_logger=None):
    """Build the vLLM AsyncLLM engine together with its metrics objects.

    In order, this function: makes sure PROMETHEUS_MULTIPROC_DIR is set, calls
    vLLM's setup_multiprocess_prometheus(), creates the LLMBackendMetrics
    gauges, sets vLLM-related environment variables, derives the vLLM config
    from config.engine_args, looks up KVBM consolidator endpoints when the
    "kvbm" connector is enabled, creates the engine, and records how long
    engine creation took.

    Args:
        config: Worker configuration. Reads engine_args, served_model_name,
            component, and has_connector().
        stat_logger: Optional stat-logger factory. When given, the freshly
            created gauges are stored on its component_gauges attribute and
            it is passed to the engine via stat_loggers.

    Returns:
        A 5-tuple (engine_client, vllm_config, default_sampling_params,
        prometheus_temp_dir, component_gauges). prometheus_temp_dir is None
        when PROMETHEUS_MULTIPROC_DIR was already present in the environment.
    """
    # vLLM v0.11.0 bug: vllm/v1.metrics/prometheus.py:79 passes TemporaryDirectory object
    # instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR
    # ourselves to avoid this and handle cleanup properly.
    prometheus_temp_dir = None
    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
        prometheus_temp_dir = tempfile.TemporaryDirectory(prefix="vllm_prometheus_")
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_temp_dir.name
        logger.debug(
            f"Created PROMETHEUS_MULTIPROC_DIR at: {os.environ['PROMETHEUS_MULTIPROC_DIR']}"
        )

    setup_multiprocess_prometheus()  # call vLLM's library's function to setup multiprocess prometheus
    logger.debug(
        f"Prometheus multiproc dir set to: {os.environ.get('PROMETHEUS_MULTIPROC_DIR')}"
    )

    # Construct Prometheus gauges AFTER setup_multiprocess_prometheus() so Gauge objects
    # see the correct ValueClass (multiprocess vs in-memory).
    component_gauges = LLMBackendMetrics(
        registry=DYNAMO_COMPONENT_REGISTRY,
        model_name=config.served_model_name or "",
        component_name=config.component or "",
    )

    # If a StatLoggerFactory was provided, give it the gauges so the loggers
    # it creates can publish Prometheus metrics.
    if stat_logger is not None:
        stat_logger.component_gauges = component_gauges

    os.environ["VLLM_NO_USAGE_STATS"] = "1"  # Avoid internal HTTP requests
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    engine_args = config.engine_args

    # LoRA-related defaults are only applied when the user has not set them.
    if engine_args.enable_lora:
        if "VLLM_ALLOW_RUNTIME_LORA_UPDATING" not in os.environ:
            os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
        if "VLLM_LORA_MODULES_LOADING_TIMEOUT" not in os.environ:
            os.environ["VLLM_LORA_MODULES_LOADING_TIMEOUT"] = "600"

    if engine_args.load_format == "gms":
        engine_args.worker_cls = "gpu_memory_service.integrations.vllm.worker.GMSWorker"

    # Load default sampling params from `generation_config.json`
    default_sampling_params = (
        engine_args.create_model_config().get_diff_sampling_param()
    )

    # Taken from build_async_engine_client_from_engine_args()
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    # Set up consolidator endpoints if KVBM is enabled
    consolidator_endpoints = None
    if config.has_connector("kvbm"):
        try:
            from kvbm.vllm_integration.consolidator_config import (
                get_consolidator_endpoints,
            )

            consolidator_endpoints = get_consolidator_endpoints(vllm_config)
        except Exception as e:
            # Deliberate best-effort: the worker keeps running without
            # consolidation when the lookup (or the import) fails.
            logger.warning(
                f"KVBM connector is enabled but failed to get consolidator endpoints: {e}. "
                "Continuing without KV event consolidation. "
                "Ensure 'kvbm' package is installed if this feature is needed."
            )
    vllm_config.consolidator_endpoints = consolidator_endpoints

    factory = []
    if stat_logger:
        factory.append(stat_logger)

    # Time engine initialization. perf_counter() is used rather than
    # time.time() so the measured duration cannot be skewed by wall-clock
    # adjustments that happen while the model is loading.
    start_time = time.perf_counter()
    engine_client = AsyncLLM.from_vllm_config(
        vllm_config=vllm_config,
        usage_context=usage_context,
        stat_loggers=factory,
        enable_log_requests=engine_args.enable_log_requests,
        disable_log_stats=engine_args.disable_log_stats,
    )
    load_time = time.perf_counter() - start_time

    # Record model load time
    component_gauges.set_model_load_time(load_time)

    logger.info(f"VllmWorker for {config.served_model_name} has been initialized")

    return (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
        component_gauges,
    )
Alec's avatar
Alec committed
508
509


510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
async def register_vllm_model(
    model_input: ModelInput,
    model_type: ModelType,
    generate_endpoint,
    config: Config,
    engine_client: AsyncLLM,
    vllm_config,
):
    """
    Helper function to register a vLLM model with runtime configuration.

    Builds a ModelRuntimeConfig from the engine's cache info and the worker
    configuration, optionally prepares a media decoder/fetcher for frontend
    decoding, and then calls register_llm().

    Args:
        model_input: Input type for the model (e.g., ModelInput.Tokens)
        model_type: Type of model (e.g., ModelType.Chat, ModelType.Prefill)
        generate_endpoint: Endpoint to register
        config: Configuration object
        engine_client: vLLM engine client
        vllm_config: vLLM configuration

    Raises:
        RuntimeError: If config.frontend_decoding is set but MediaDecoder and
            MediaFetcher could not be imported from dynamo.llm.
    """
    runtime_config = ModelRuntimeConfig()

    # Get runtime configuration from vLLM engine.
    # Logged through the module logger, like every other message in this file,
    # rather than through the root logger.
    logger.info(
        f"Getting engine runtime configuration metadata from vLLM engine for {model_type}..."
    )
    runtime_values = get_engine_cache_info(engine_client)
    runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
    runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
    runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
    # Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer
    runtime_config.enable_local_indexer = (
        config.enable_local_indexer and not config.is_decode_worker
    )

    # Add tool/reasoning parsers for every model type except Prefill
    if model_type != ModelType.Prefill:
        runtime_config.tool_call_parser = config.tool_call_parser
        runtime_config.reasoning_parser = config.reasoning_parser

    # Get data_parallel_size from vllm_config (defaults to 1)
    data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
    runtime_config.data_parallel_size = data_parallel_size

    # Configure media decoder for frontend image decoding when enabled
    # This enables frontend to decode images and transfer via NIXL RDMA
    media_decoder = None
    media_fetcher = None
    if config.frontend_decoding:
        if not MEDIA_DECODER_AVAILABLE:
            raise RuntimeError(
                "--frontend-decoding requires MediaDecoder support. "
                "Ensure dynamo.llm module includes MediaDecoder and MediaFetcher."
            )
        media_decoder = MediaDecoder()
        # 128 * 1024 * 1024 = 128 MiB allocation limit for image decoding.
        media_decoder.enable_image({"limits": {"max_alloc": 128 * 1024 * 1024}})
        # media_decoder.enable_video({})

        media_fetcher = MediaFetcher()
        media_fetcher.timeout_ms(30000)
        media_fetcher.allow_direct_port(True)

    await register_llm(
        model_input,
        model_type,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=runtime_values["block_size"],
        runtime_config=runtime_config,
        custom_template_path=config.custom_jinja_template,
        media_decoder=media_decoder,
        media_fetcher=media_fetcher,
    )


585
async def init_prefill(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    pre_created_engine=None,
):
    """
    Instantiate a prefill worker and serve its endpoints.

    Args:
        runtime: Distributed runtime used to resolve the component and to
            register the sleep/wake_up engine routes.
        config: Worker configuration.
        shutdown_event: Event handed to the handler for shutdown coordination.
        pre_created_engine: Optional 5-tuple as returned by setup_vllm_engine()
            (checkpoint mode). When None, a new engine is created here.
    """
    component = runtime.namespace(config.namespace).component(config.component)
    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")

    # Reuse the engine built in checkpoint mode, otherwise build one now.
    engine_bundle = (
        pre_created_engine
        if pre_created_engine is not None
        else setup_vllm_engine(config)
    )
    (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
        _component_gauges,
    ) = engine_bundle

    # max_model_len is looked up defensively: either attribute may be absent.
    model_config = getattr(vllm_config, "model_config", None)
    max_model_len = getattr(model_config, "max_model_len", None)

    handler = PrefillWorkerHandler(
        runtime,
        component,
        engine_client,
        default_sampling_params,
        max_model_len,
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
        shutdown_event=shutdown_event,
        enable_frontend_decoding=config.frontend_decoding,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
    consolidator_enabled = False
    consolidator_port = None
    endpoints = getattr(vllm_config, "consolidator_endpoints", None)
    if endpoints:
        # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint);
        # clients subscribe to the connect endpoint, i.e. the third element.
        connect_endpoint = endpoints[2]
        consolidator_port = int(connect_endpoint.split(":")[-1])
        consolidator_enabled = True

    # Set up KV event publishers for prefix caching if enabled (one per dp_rank).
    # With the consolidator enabled, the publishers subscribe to its output.
    kv_publishers = setup_kv_event_publisher(
        config,
        component,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    setup_metrics_collection(config, generate_endpoint, logger)

    # Register sleep/wake_up engine routes
    runtime.register_engine_route("sleep", handler.sleep)
    runtime.register_engine_route("wake_up", handler.wake_up)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # Handle non-leader nodes - don't serve endpoints
    dp_rank = config.engine_args.data_parallel_rank
    if dp_rank:
        await _handle_non_leader_node(dp_rank)
        return

    # Register prefill model with ModelType.Prefill
    if config.use_vllm_tokenizer:
        model_input = ModelInput.Text
    else:
        model_input = ModelInput.Tokens
    await register_vllm_model(
        model_input,
        ModelType.Prefill,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
    )

    health_check = VllmPrefillHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    )
    health_check_payload = health_check.to_dict()

    try:
        logger.debug("Starting serve_endpoint for prefill worker")
        # for prefill, we want to shutdown the engine after all prefill requests are finished because
        #     (temp reason): we don't support re-routing prefill requests
        #     (long-term reason): prefill engine should pull from a global queue so there is
        #                         only a few in-flight requests that can be quickly finished
        serve_generate = generate_endpoint.serve_endpoint(
            handler.generate,
            graceful_shutdown=True,
            # In practice config.served_model_name is always set, but mypy needs the "or" here.
            metrics_labels=[("model", config.served_model_name or config.model)],
            health_check_payload=health_check_payload,
        )
        serve_clear = clear_endpoint.serve_endpoint(
            handler.clear_kv_blocks,
            metrics_labels=[("model", config.served_model_name)],
        )
        await asyncio.gather(serve_generate, serve_clear)
        logger.debug("serve_endpoint completed for prefill worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        # Runs on success and failure alike so the handler can release resources.
        logger.debug("Cleaning up prefill worker")
        handler.cleanup()


714
async def init(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    pre_created_engine=None,
):
    """
    Instantiate and serve the decode worker.

    Sets up (or adopts) the vLLM engine, wires the stat logger factory, the
    KV event publisher, metrics collection and the sleep/wake_up engine
    routes, registers the model, and then serves the generate,
    clear_kv_blocks and LoRA management endpoints.

    Args:
        runtime: Distributed runtime used to resolve the component and to
            register the engine routes.
        config: Worker configuration.
        shutdown_event: Event forwarded to the ``DecodeWorkerHandler``.
        pre_created_engine: Optional 5-tuple ``(engine_client, vllm_config,
            default_sampling_params, prometheus_temp_dir, component_gauges)``.
            When provided (checkpoint mode) no new engine is created.

    Raises:
        Exception: Any error raised while serving the endpoints is logged
            and re-raised; ``handler.cleanup()`` runs in either case.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")
    load_lora_endpoint = component.endpoint("load_lora")
    unload_lora_endpoint = component.endpoint("unload_lora")
    list_loras_endpoint = component.endpoint("list_loras")

    model_name = config.served_model_name or config.model

    # Use pre-created engine if provided (checkpoint mode), otherwise create new
    if pre_created_engine is not None:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            component_gauges,
        ) = pre_created_engine
        # Factory is created after unpack so component_gauges is available
        factory = StatLoggerFactory(
            component,
            component_gauges=component_gauges,
            dp_rank=config.engine_args.data_parallel_rank or 0,
            metrics_labels=[("model", model_name)],
        )
    else:
        # Factory is created without component_gauges; setup_vllm_engine() will
        # create the gauges after setup_multiprocess_prometheus() and set them
        # on the factory before vLLM calls create_stat_logger().
        factory = StatLoggerFactory(
            component,
            dp_rank=config.engine_args.data_parallel_rank or 0,
            metrics_labels=[("model", model_name)],
        )
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            _component_gauges,  # not used again on this path
        ) = setup_vllm_engine(config, factory)

    # TODO Hack to get data, move this to registering in TBD
    factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks)
    factory.init_publish()

    handler = DecodeWorkerHandler(
        runtime,
        component,
        engine_client,
        default_sampling_params,
        # max_model_len, or None when vllm_config has no model_config
        getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
        shutdown_event=shutdown_event,
        enable_frontend_decoding=config.frontend_decoding,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
    consolidator_enabled = False
    consolidator_port = None

    consolidator_endpoints = getattr(vllm_config, "consolidator_endpoints", None)
    if consolidator_endpoints:
        # Extract connect endpoint (third element) for clients to subscribe
        # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint)
        consolidator_output_endpoint = consolidator_endpoints[2]
        consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
        consolidator_enabled = True

    # Set up KV event publisher for prefix caching if enabled
    # If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
    kv_publishers = setup_kv_event_publisher(
        config,
        component,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    setup_metrics_collection(config, generate_endpoint, logger)

    # Register sleep/wake_up engine routes
    runtime.register_engine_route("sleep", handler.sleep)
    runtime.register_engine_route("wake_up", handler.wake_up)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # Handle non-leader nodes - don't serve endpoints
    # (a truthy, i.e. non-zero, data_parallel_rank takes this path)
    if config.engine_args.data_parallel_rank:
        await _handle_non_leader_node(config.engine_args.data_parallel_rank)
        return

    # Parse endpoint types from --dyn-endpoint-types flag
    model_type = parse_endpoint_types(config.dyn_endpoint_types)
    logger.info(f"Registering model with endpoint types: {config.dyn_endpoint_types}")

    model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens

    # Warn if custom template provided but chat endpoint not enabled
    if config.custom_jinja_template and "chat" not in config.dyn_endpoint_types:
        logger.warning(
            "Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
            "The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
        )

    await register_vllm_model(
        model_input,
        model_type,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
    )

    health_check_payload = VllmHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    ).to_dict()

    # Every endpoint served below carries the same model label.
    metrics_labels = [("model", model_name)]

    try:
        logger.debug("Starting serve_endpoint for decode worker")
        await asyncio.gather(
            # for decode, we want to transfer the in-flight requests to other decode engines,
            # because waiting them to finish can take a long time for long OSLs
            # NOTE(review): graceful_shutdown=True is the same value the prefill
            # path passes; confirm it matches the migration intent described above.
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=metrics_labels,
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=metrics_labels,
            ),
            load_lora_endpoint.serve_endpoint(
                handler.load_lora,
                metrics_labels=metrics_labels,
            ),
            unload_lora_endpoint.serve_endpoint(
                handler.unload_lora,
                metrics_labels=metrics_labels,
            ),
            list_loras_endpoint.serve_endpoint(
                handler.list_loras,
                metrics_labels=metrics_labels,
            ),
        )
        logger.debug("serve_endpoint completed for decode worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up decode worker")
        # Cleanup background tasks
        handler.cleanup()


889
890
891
892
893
894
895
def get_engine_cache_info(engine: AsyncLLM):
    """Retrieve cache configuration information from [`AsyncLLM`] engine.

    Reads the values straight from ``engine.vllm_config`` and logs them.

    Args:
        engine: Engine whose ``vllm_config`` exposes ``cache_config`` and
            ``scheduler_config``.

    Returns:
        A dict with the keys ``num_gpu_blocks``, ``block_size``,
        ``max_num_seqs`` and ``max_num_batched_tokens``.

    Raises:
        Exception: Any failure while reading the config is logged and
            re-raised unchanged.
    """

    try:
        # Get values directly from vllm_config instead of collective_rpc
        cache_config = engine.vllm_config.cache_config
        scheduler_config = engine.vllm_config.scheduler_config

        cache_values = {
            "num_gpu_blocks": cache_config.num_gpu_blocks,
            "block_size": cache_config.block_size,
        }

        scheduler_values = {
            "max_num_seqs": scheduler_config.max_num_seqs,
            "max_num_batched_tokens": scheduler_config.max_num_batched_tokens,
        }

        logging.info(f"Cache config values: {cache_values}")
        logging.info(f"Scheduler config values: {scheduler_values}")
        # Same keys, same order as listing the four entries by hand.
        return {**cache_values, **scheduler_values}
    except Exception as e:
        logging.error(f"Failed to get configuration values from vLLM config: {e}")
        raise


917
918
919
async def init_multimodal_processor(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """Initialize multimodal processor component.

    Creates clients for the "encoder" and "backend" components in the
    configured namespace, waits for encoder instances, registers the
    generate endpoint as a chat model entrypoint and serves it.

    Args:
        runtime: Distributed runtime used to resolve components and clients.
        config: Worker configuration.
        shutdown_event: Not referenced in this function body.

    Raises:
        Exception: Any error raised while serving is logged and re-raised;
            ``handler.cleanup()`` runs in either case.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)

    # Get encode worker client
    encode_worker_client = (
        await runtime.namespace(config.namespace)
        .component("encoder")
        .endpoint("generate")
        .client()
    )

    # Get PD worker client
    pd_worker_client = (
        await runtime.namespace(config.namespace)
        .component("backend")
        .endpoint("generate")
        .client()
    )

    handler = PreprocessedHandler(
        config.engine_args,
        encode_worker_client,
        pd_worker_client,
    )

    # Only the encoder client is awaited here; the PD client is not.
    logger.info("Waiting for Encoder Worker Instances ...")
    await encode_worker_client.wait_for_instances()

    # Register the endpoint as entrypoint to a model
    await register_llm(
        ModelInput.Tokens,
        ModelType.Chat,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info("Starting to serve the processor endpoint...")

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate, metrics_labels=[("model", config.model)]
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()


974
975
976
async def init_multimodal_encode_worker(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """Initialize multimodal encode worker component.

    Creates a client for the "backend" component's generate endpoint, builds
    and asynchronously initializes an ``EncodeWorkerHandler``, waits for PD
    worker instances and then serves the generate endpoint.

    Args:
        runtime: Distributed runtime used to resolve components and clients.
        config: Worker configuration.
        shutdown_event: Not referenced in this function body.

    Raises:
        Exception: Any error raised while serving is logged and re-raised;
            ``handler.cleanup()`` runs in either case.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)

    # Get PD worker client
    # In multimodal mode, the PD worker always registers as "backend"
    # (even in disaggregated mode with prefill/decode split, we still connect to "backend")
    pd_worker_client = (
        await runtime.namespace(config.namespace)
        .component("backend")
        .endpoint("generate")
        .client()
    )

    handler = EncodeWorkerHandler(
        config.engine_args,
        pd_worker_client,
    )
    await handler.async_init(runtime)

    logger.info("Waiting for PD Worker Instances ...")
    await pd_worker_client.wait_for_instances()
    logger.info("Starting to serve the encode worker endpoint...")

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate, metrics_labels=[("model", config.model)]
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve encode worker endpoint: {e}")
        raise
    finally:
        handler.cleanup()


1014
1015
1016
async def init_vllm_native_encoder(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """
    Initialize vLLM-native encoder worker component (ECConnector mode).
    In this mode, vLLM handles encoder execution, caching, and storage automatically.

    Args:
        runtime: Distributed runtime used to resolve the component.
        config: Worker configuration; ``config.engine_args.ec_transfer_config``
            is overwritten with a producer-role transfer config.
        shutdown_event: Not referenced in this function body.

    Raises:
        Exception: Any error raised while serving is logged and re-raised;
            ``handler.cleanup()`` runs in either case.
    """
    # Create component and endpoint
    component = runtime.namespace(config.namespace).component(config.component)
    generate_endpoint = component.endpoint(config.endpoint)

    # Configure ECTransferConfig for producer role
    instance_id = 0
    engine_id = f"{config.namespace}.{config.component}.encoder.{instance_id}"

    # Configure encoder with producer role, it will be responsible for creating embeddings and storing them in the shared storage
    ec_transfer_config = create_ec_transfer_config(
        engine_id=engine_id,
        ec_role="ec_producer",
        ec_connector_backend=config.ec_connector_backend,
        ec_storage_path=config.ec_storage_path,
        ec_extra_config=config.ec_extra_config,
    )

    # Set ECTransferConfig on engine args
    config.engine_args.ec_transfer_config = ec_transfer_config

    # Setup vLLM engine; only the client and the prometheus temp dir are
    # used below, the remaining tuple members are intentionally ignored.
    (
        engine_client,
        _vllm_config,
        _default_sampling_params,
        prometheus_temp_dir,
        _component_gauges,
    ) = setup_vllm_engine(config)

    # Initialize vLLM Native Encoder Worker Handler
    handler = VLLMEncodeWorkerHandler(
        runtime,
        component,
        engine_client,
        config,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # No handler.async_init() call here: not needed for ECConnector mode,
    # vLLM handles everything.

    logger.info("Starting to serve vLLM-native encoder endpoint...")

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate, metrics_labels=[("model", config.model)]
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve vLLM-native encoder endpoint: {e}")
        raise
    finally:
        handler.cleanup()


1078
1079
1080
async def init_ec_processor(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """
    Initialize ECConnector processor component.

    Simple processor that routes multimodal requests using ECConnector pattern:
    1. Preprocess request (same as regular processor)
    2. Send multimodal items to encoder workers (stores to shared storage)
    3. Forward preprocessed request to PD worker (loads from shared storage)
    4. Stream response back to client

    Args:
        runtime: Distributed runtime used to resolve components and clients.
        config: Worker configuration.
        shutdown_event: Not referenced in this function body.

    Raises:
        Exception: Any error raised while serving is logged and re-raised;
            ``handler.cleanup()`` runs in either case.
    """
    # Create component and endpoint
    component = runtime.namespace(config.namespace).component(config.component)
    generate_endpoint = component.endpoint(config.endpoint)

    # Get encoder worker client
    encoder_client = (
        await runtime.namespace(config.namespace)
        .component("encoder")
        .endpoint("generate")
        .client()
    )

    # Get PD worker client
    pd_client = (
        await runtime.namespace(config.namespace)
        .component("backend")
        .endpoint("generate")
        .client()
    )

    # Create EC processor handler (with preprocessing like regular processor).
    # The prompt template comes from args (must be passed via environment or
    # command line).
    handler = ECProcessorHandler(
        config.engine_args,
        encoder_worker_client=encoder_client,
        pd_worker_client=pd_client,
        prompt_template=config.mm_prompt_template,
    )

    logger.info("Waiting for encoder and PD worker instances...")
    await encoder_client.wait_for_instances()
    await pd_client.wait_for_instances()

    # Register the endpoint as entrypoint to a model (same as preprocessed_handler)
    await register_llm(
        ModelInput.Tokens,  # Use Rust tokenization for better performance and multi-image support
        ModelType.Chat,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info("Starting to serve EC processor endpoint...")

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate, metrics_labels=[("model", config.model)]
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve EC processor endpoint: {e}")
        raise
    finally:
        handler.cleanup()


1150
async def init_multimodal_worker(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    pre_created_engine=None,
):
    """
    Initialize multimodal worker component.

    Supports two modes:
    1. --multimodal-worker: Receives embeddings from separate encoder
    2. --multimodal-encode-prefill-worker: Handles inline encoding (e.g., Llama 4)

    Both can operate in aggregated (P+D) or disaggregated (P→D) mode.

    When --ec-consumer-mode is enabled, configures as ECConnector consumer
    to load encoder embeddings from shared storage.

    Args:
        runtime: Distributed runtime used to resolve components and clients.
        config: Worker configuration; in EC consumer mode
            ``config.engine_args.ec_transfer_config`` is overwritten.
        shutdown_event: Event forwarded to the selected handler.
        pre_created_engine: Optional 5-tuple ``(engine_client, vllm_config,
            default_sampling_params, prometheus_temp_dir, component_gauges)``.
            When provided (checkpoint mode) no new engine is created.

    Raises:
        Exception: Any error raised while serving is logged and re-raised;
            ``handler.cleanup()`` runs in either case.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")

    # Configure ECConnector consumer mode if enabled
    if config.ec_consumer_mode:
        logger.info("Configuring as ECConnector consumer for encoder embeddings")
        instance_id = 0
        engine_id = f"{config.namespace}.{config.component}.backend.{instance_id}"

        # The PD Worker just load the embeddings from the shared storage, so it is a consumer
        ec_transfer_config = create_ec_transfer_config(
            engine_id=engine_id,
            ec_role="ec_consumer",
            ec_connector_backend=config.ec_connector_backend,
            ec_storage_path=config.ec_storage_path,
            ec_extra_config=config.ec_extra_config,
        )

        # Set ECTransferConfig on engine args
        config.engine_args.ec_transfer_config = ec_transfer_config
        logger.info(f"Configured as ECConnector consumer with engine_id={engine_id}")

    # Use pre-created engine if provided (checkpoint mode), otherwise create new
    engine_bundle = (
        pre_created_engine
        if pre_created_engine is not None
        else setup_vllm_engine(config)
    )
    (
        engine_client,
        vllm_config,
        _default_sampling_params,  # unused by the multimodal handlers below
        prometheus_temp_dir,
        _component_gauges,
    ) = engine_bundle

    # Set up decode worker client for disaggregated mode
    decode_worker_client = None
    if config.is_prefill_worker:
        # Prefill worker needs to connect to decode worker
        decode_worker_client = (
            await runtime.namespace(config.namespace)
            .component("decoder")
            .endpoint("generate")
            .client()
        )
        await decode_worker_client.wait_for_instances()
        logger.info("Connected to decode worker for disaggregated mode")

    # Choose handler based on worker type
    if config.multimodal_decode_worker:
        handler = MultimodalDecodeWorkerHandler(
            runtime, component, engine_client, config, shutdown_event
        )
    else:
        handler = MultimodalPDWorkerHandler(
            runtime,
            component,
            engine_client,
            config,
            decode_worker_client,
            shutdown_event,
        )
    handler.add_temp_dir(prometheus_temp_dir)

    await handler.async_init(runtime)

    # Set up KV event publisher for prefix caching if enabled
    kv_publisher = setup_kv_event_publisher(
        config, component, generate_endpoint, vllm_config
    )
    if kv_publisher:
        # NOTE(review): init() stores the result of the same helper on
        # ``handler.kv_publishers`` (plural). Confirm which attribute the
        # multimodal handlers actually read before renaming this one.
        handler.kv_publisher = kv_publisher

    metrics_labels = [("model", config.model)]
    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                metrics_labels=metrics_labels,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=metrics_labels,
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()


1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
async def init_omni(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """
    Initialize the Omni worker backed by the vLLM-Omni orchestrator.

    Builds an ``OmniHandler``, sets up metrics collection, registers the
    generate endpoint with ``ModelInput.Text`` / ``ModelType.Images`` and
    serves it with a health-check payload.

    NOTE(review): an earlier version of this docstring described the worker
    as text-to-text only (no multimodal), while the registration below uses
    ``ModelType.Images`` -- confirm which is intended.

    Args:
        runtime: Distributed runtime used to resolve the component.
        config: Worker configuration.
        shutdown_event: Event forwarded to the ``OmniHandler``.

    Raises:
        Exception: Any error raised while serving is logged and re-raised;
            ``handler.cleanup()`` runs in either case.
    """
    # Lazy import to avoid loading vllm-omni unless explicitly needed
    from dynamo.vllm.omni import OmniHandler

    component = runtime.namespace(config.namespace).component(config.component)
    generate_endpoint = component.endpoint(config.endpoint)

    # Initialize OmniHandler with Omni orchestrator
    handler = OmniHandler(
        runtime=runtime,
        component=component,
        config=config,
        default_sampling_params={},
        shutdown_event=shutdown_event,
    )

    logger.info(f"Omni worker initialized for model: {config.model}")

    # Set up metrics collection for vLLM and LMCache metrics
    setup_metrics_collection(config, generate_endpoint, logger)

    # Handle non-leader nodes - don't serve endpoints
    # (a truthy, i.e. non-zero, data_parallel_rank takes this path)
    if config.engine_args.data_parallel_rank:
        await _handle_non_leader_node(config.engine_args.data_parallel_rank)
        return

    # TODO: extend for multi-stage pipelines
    await register_llm(
        ModelInput.Text,
        ModelType.Images,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info("Starting to serve Omni worker endpoint...")

    health_check_payload = (
        await VllmOmniHealthCheckPayload.create(handler.engine_client)
    ).to_dict()

    try:
        await generate_endpoint.serve_endpoint(
            handler.generate,
            graceful_shutdown=True,
            metrics_labels=[("model", config.served_model_name or config.model)],
            health_check_payload=health_check_payload,
        )
    except Exception as e:
        logger.error(f"Failed to serve Omni endpoint: {e}")
        raise
    finally:
        logger.debug("Cleaning up Omni worker")
        handler.cleanup()


Alec's avatar
Alec committed
1332
1333
1334
1335
def main():
    """Synchronous entry point: run the ``worker()`` coroutine under uvloop."""
    entrypoint = worker()
    uvloop.run(entrypoint)


Alec's avatar
Alec committed
1336
# Run the worker only when this module is executed as a script.
if __name__ == "__main__":
    main()