"examples/backends/vllm/launch/agg_multimodal.sh" did not exist on "93208162753986f9449d3671d6a263dfc4f4381e"
main.py 49.3 KB
Newer Older
1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Alec's avatar
Alec committed
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0

import asyncio
import logging
import os
7
import tempfile
8
import time
9
from typing import Optional
Alec's avatar
Alec committed
10
11

import uvloop
12
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
Alec's avatar
Alec committed
13
14
15
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
16
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
Alec's avatar
Alec committed
17

18
from dynamo import prometheus_names
19
from dynamo.common.config_dump import dump_config
20
from dynamo.common.utils.endpoint_types import parse_endpoint_types
21
22
23
24
from dynamo.common.utils.prometheus import (
    LLMBackendMetrics,
    register_engine_metrics_callback,
)
25
from dynamo.common.utils.runtime import create_runtime
Alec's avatar
Alec committed
26
from dynamo.llm import (
27
    KvEventPublisher,
28
    ModelInput,
29
    ModelRuntimeConfig,
Alec's avatar
Alec committed
30
    ModelType,
31
    fetch_llm,
Alec's avatar
Alec committed
32
33
    register_llm,
)
34
35
36
37
38
39
40
41
42
43
44

# Optional imports for frontend decoding support
try:
    from dynamo.llm import MediaDecoder, MediaFetcher

    MEDIA_DECODER_AVAILABLE = True
except ImportError:
    MediaDecoder = None
    MediaFetcher = None
    MEDIA_DECODER_AVAILABLE = False

45
from dynamo.runtime import DistributedRuntime
Alec's avatar
Alec committed
46
from dynamo.runtime.logging import configure_dynamo_logging
47
from dynamo.vllm.multimodal_handlers import (
48
    ECProcessorHandler,
49
    EncodeWorkerHandler,
Ayush Agarwal's avatar
Ayush Agarwal committed
50
    MultimodalDecodeWorkerHandler,
51
    MultimodalPDWorkerHandler,
GuanLuo's avatar
GuanLuo committed
52
    PreprocessedHandler,
53
    VLLMEncodeWorkerHandler,
54
)
55
from dynamo.vllm.multimodal_utils.encode_utils import create_ec_transfer_config
Alec's avatar
Alec committed
56

57
from .args import Config, parse_args
58
from .chrek import get_checkpoint_config
Alec's avatar
Alec committed
59
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
60
61
62
63
64
from .health_check import (
    VllmHealthCheckPayload,
    VllmOmniHealthCheckPayload,
    VllmPrefillHealthCheckPayload,
)
65
from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory

configure_dynamo_logging()
logger = logging.getLogger(__name__)

# Sleep level handed to ``checkpoint_cfg.run_lifecycle()`` in ``worker()``.
# Checkpoint mode enables vLLM sleep mode before the engine is created.
CHECKPOINT_SLEEP_MODE_LEVEL = 1


async def _handle_non_leader_node(dp_rank: int) -> None:
    """Park a non-leader data-parallel rank without serving Dynamo endpoints.

    In multi-node deployments, ranks with a non-zero ``data_parallel_rank``
    still run vLLM workers but must not register or serve Dynamo endpoints.
    This coroutine never returns on its own; the process is expected to be
    terminated via its signal handlers.

    Args:
        dp_rank: The vLLM data-parallel rank of this process.
    """
    logger.info(
        f"Non-leader node detected (data_parallel_rank={dp_rank}). "
        "Skipping endpoint serving."
    )
    # Block forever: nothing ever sets this freshly created event.
    await asyncio.Event().wait()


async def graceful_shutdown(runtime, shutdown_event):
    """Shut down the Dynamo distributed runtime.

    Sets ``shutdown_event`` and then calls ``runtime.shutdown()``.
    The endpoints are immediately invalidated, so no new requests are accepted.
    Endpoints served with ``graceful_shutdown=True`` wait until all in-flight
    requests are finished; endpoints served with ``graceful_shutdown=False``
    return immediately.

    Args:
        runtime: The ``DistributedRuntime`` to shut down.
        shutdown_event: ``asyncio.Event`` that is set before the runtime is
            shut down.
    """
    # Use the module logger (not the root ``logging`` module) like the rest of
    # this file so these records carry the module's logger name.
    logger.info("Received shutdown signal, shutting down DistributedRuntime")
    shutdown_event.set()
    runtime.shutdown()
    logger.info("DistributedRuntime shutdown complete")


async def worker():
    """Parse arguments, create the runtime and start the selected worker role.

    Exactly one initializer runs, chosen by the first matching config flag in
    this order: vLLM-native encoder, EC processor, multimodal processor,
    multimodal encode worker, multimodal (decode / encode-prefill) worker,
    omni, prefill worker, and finally the default ``init()``.

    In checkpoint mode the vLLM engine is created *before* the distributed
    runtime, so GPU state can be checkpointed before any runtime connections
    are established.
    """
    config = parse_args()

    dump_config(config.dump_config_to, config)

    # Name the model. Use either the full path (vllm and sglang do the same),
    # or the HF name (e.g. "Qwen/Qwen3-0.6B"), depending on cmd line params.
    if not config.served_model_name:
        config.served_model_name = config.engine_args.served_model_name = config.model

    # Check checkpoint mode and validate env vars EARLY (fail fast if misconfigured)
    checkpoint_cfg = get_checkpoint_config()
    if checkpoint_cfg and checkpoint_cfg.checkpoint_exists():
        # NOTE(review): an existing checkpoint makes this process return
        # without serving; presumably the workload is restored from the
        # checkpoint elsewhere — confirm.
        return

    # Download the model if necessary using modelexpress.
    # We want it on disk before we start vllm to avoid downloading from HuggingFace.
    #
    # We don't set `config.engine_args.model` to the local path fetch_llm returns
    # because vllm will send that name to its Ray pipeline-parallel workers, which
    # may not have the local path.
    # vllm will attempt to download the model again, but find it in the HF cache.
    # For non-HF models use a path instead of an HF name, and ensure all workers have
    # that path (ideally via a shared folder).
    if not os.path.exists(config.model):
        await fetch_llm(config.model)

    # CHECKPOINT MODE: Load engine BEFORE runtime creation
    # This allows checkpointing GPU state before runtime connections are established
    pre_created_engine = None
    if checkpoint_cfg is not None:
        logger.info(
            f"Checkpoint mode enabled (signal_file={checkpoint_cfg.signal_file})"
        )

        # Checkpoint mode requires sleep mode — enable before engine init
        config.engine_args.enable_sleep_mode = True

        pre_created_engine = setup_vllm_engine(config)
        engine_client = pre_created_engine[0]

        if not await checkpoint_cfg.run_lifecycle(
            engine_client, CHECKPOINT_SLEEP_MODE_LEVEL
        ):
            return

    shutdown_event = asyncio.Event()
    runtime, _ = create_runtime(
        store_kv=config.store_kv,
        request_plane=config.request_plane,
        event_plane=config.event_plane,
        use_kv_events=config.use_kv_events,
        shutdown_event=shutdown_event,
    )

    # Route to appropriate initialization based on config flags.
    # NOTE(review): pre_created_engine is only forwarded to the multimodal
    # worker, prefill and default init paths; the other paths ignore it.
    if config.vllm_native_encoder_worker:
        await init_vllm_native_encoder(runtime, config, shutdown_event)
        logger.debug("init_vllm_native_encoder completed")
    elif config.ec_processor:
        await init_ec_processor(runtime, config, shutdown_event)
        logger.debug("init_ec_processor completed")
    elif config.multimodal_processor:
        await init_multimodal_processor(runtime, config, shutdown_event)
        logger.debug("init_multimodal_processor completed")
    elif config.multimodal_encode_worker:
        await init_multimodal_encode_worker(runtime, config, shutdown_event)
        logger.debug("init_multimodal_encode_worker completed")
    elif (
        config.multimodal_worker
        or config.multimodal_decode_worker
        or config.multimodal_encode_prefill_worker
    ):
        await init_multimodal_worker(
            runtime, config, shutdown_event, pre_created_engine=pre_created_engine
        )
        logger.debug("init_multimodal_worker completed")
    elif config.omni:
        await init_omni(runtime, config, shutdown_event)
        logger.debug("init_omni completed")
    elif config.is_prefill_worker:
        await init_prefill(
            runtime, config, shutdown_event, pre_created_engine=pre_created_engine
        )
        logger.debug("init_prefill completed")
    else:
        await init(
            runtime, config, shutdown_event, pre_created_engine=pre_created_engine
        )
        logger.debug("init completed")

    logger.debug("Worker function completed, exiting...")


def setup_metrics_collection(config: Config, generate_endpoint, logger):
    """Set up metrics collection for vLLM and LMCache metrics.

    In multiprocess mode (PROMETHEUS_MULTIPROC_DIR set), metrics are stored:
      1. In-memory: Metric objects in global REGISTRY
      2. On-disk: Metric values in .db files (PROMETHEUS_MULTIPROC_DIR)

    MultiProcessCollector reads from .db files but adding it to REGISTRY can fail
    with "Duplicated timeseries" if PROMETHEUS_MULTIPROC_DIR was set before process
    started (K8s deployments) because metrics are already in REGISTRY.

    Solution: Try adding MultiProcessCollector to REGISTRY. If that fails, use
    separate registry for multiprocess collection and register callbacks to both
    registries to ensure all metrics (vllm, lmcache, dynamo_component) are collected.

    Auto-label injection:
        Every vLLM/LMCache callback is registered with the namespace, component,
        endpoint and model name from ``config`` so that Python engine metrics
        line up with the Rust-side auto-labels.

    Nothing is registered unless ``config.engine_args.disable_log_stats`` is
    exactly ``False``.

    Args:
        config: Worker configuration.
        generate_endpoint: Endpoint the metrics callbacks are registered on.
        logger: Logger used for debug output.
    """
    # Preserve the original strict check: only run when stats are explicitly on.
    if config.engine_args.disable_log_stats is not False:
        return

    # Register the dedicated dynamo_component registry callback
    # IMPORTANT: We do NOT use MultiProcessCollector for DYNAMO_COMPONENT_REGISTRY
    # because our gauges use in-memory values which work fine for single-process
    # and multi-process (each process has its own gauge with dp_rank label).
    # Using MultiProcessCollector would read from .db files which causes stale
    # values to accumulate across test runs.
    register_engine_metrics_callback(
        endpoint=generate_endpoint,
        registry=DYNAMO_COMPONENT_REGISTRY,
    )

    def _register_labeled(registry, prefix_filters):
        # Register an engine-metrics callback carrying the hierarchy labels.
        register_engine_metrics_callback(
            endpoint=generate_endpoint,
            registry=registry,
            metric_prefix_filters=prefix_filters,
            namespace_name=config.namespace,
            component_name=config.component,
            endpoint_name=config.endpoint,
            model_name=config.model,
        )

    if not os.environ.get("PROMETHEUS_MULTIPROC_DIR"):
        # No multiprocess mode
        _register_labeled(REGISTRY, ["vllm:", "lmcache:"])
        return

    # Only the collector construction belongs in the try: a ValueError raised
    # while registering callbacks must not be mistaken for the duplicated
    # timeseries conflict (that would register a second set of callbacks).
    try:
        # MultiProcessCollector reads metrics from .db files in PROMETHEUS_MULTIPROC_DIR
        # Adding it to REGISTRY allows collecting both in-memory and .db file metrics
        multiprocess.MultiProcessCollector(REGISTRY)
    except ValueError as e:
        # Conflict: metrics already in REGISTRY, MultiProcessCollector tries to add same metrics from .db files
        # Solution: Use separate registry that ONLY reads from .db files (no in-memory conflicts)
        logger.debug(
            f"Could not add MultiProcessCollector to REGISTRY ({e}), using separate registry"
        )
        multiproc_registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(multiproc_registry)

        # Register both registries to collect all metrics
        # Global REGISTRY has in-memory metrics (vllm)
        _register_labeled(REGISTRY, ["vllm:"])
        # Multiproc registry has .db file metrics (lmcache, possibly vllm duplicates)
        _register_labeled(multiproc_registry, ["vllm:", "lmcache:"])
    else:
        logger.debug("Added MultiProcessCollector to global REGISTRY")
        _register_labeled(REGISTRY, ["vllm:", "lmcache:"])


def setup_kv_event_publisher(
    config: Config,
    component,
    generate_endpoint,
    vllm_config,
    consolidator_enabled: bool = False,
    consolidator_port: Optional[int] = 5558,
) -> Optional[list[KvEventPublisher]]:
    """Set up KV event publishers for prefix caching if enabled.

    Creates one publisher per dp_rank, since each dp_rank publishes to a
    different port.

    Args:
        config: Worker configuration.
        component: Component the publishers are created for.
        generate_endpoint: Not referenced by this function; kept so existing
            call sites keep working.
        vllm_config: vLLM configuration (cache block size and
            ``parallel_config.data_parallel_size`` are read).
        consolidator_enabled: If True, subscribe to the KV event consolidator's
            ZMQ endpoint instead of vLLM's per-rank endpoints.
        consolidator_port: Port where the KV event consolidator publishes
            (default: 5558). Must not be None when ``consolidator_enabled``.

    Returns:
        List of KvEventPublisher instances (one per dp_rank), or None when
        prefix caching or KV cache events are disabled, when this is a decode
        worker, or when no publisher was created.

    Raises:
        ValueError: If ``consolidator_enabled`` is True but
            ``consolidator_port`` is None (the endpoint would otherwise be the
            invalid ``tcp://127.0.0.1:None``).
    """
    if not config.engine_args.enable_prefix_caching:
        return None

    # Skip KV event publishing for decode workers
    if config.is_decode_worker:
        logger.info("Skipping KV event publisher setup for decode worker")
        return None

    kv_events_config = config.engine_args.kv_events_config
    if kv_events_config is None:
        return None

    # Check if kv_cache_events are explicitly disabled
    if not kv_events_config.enable_kv_cache_events:
        logger.info(
            "KV event publishing skipped: enable_kv_cache_events=False in kv_events_config"
        )
        return None

    if consolidator_enabled and consolidator_port is None:
        raise ValueError(
            "consolidator_port must be set when consolidator_enabled is True"
        )

    # Get data_parallel_size to create publishers for all dp_ranks
    data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
    kv_publishers = []

    for dp_rank in range(data_parallel_size):
        if consolidator_enabled:
            # TODO: Use different port for each dp_rank once KVBM supports DP
            zmq_endpoint = f"tcp://127.0.0.1:{consolidator_port}"
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to consolidator at {zmq_endpoint}"
            )
        else:
            # Each dp_rank publishes to a different port; vLLM binds on "*",
            # so connect to it over loopback.
            zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
                kv_events_config.endpoint,
                data_parallel_rank=dp_rank,
            ).replace("*", "127.0.0.1")
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to vLLM at {zmq_endpoint}"
            )

        kv_publishers.append(
            KvEventPublisher(
                component=component,
                kv_block_size=vllm_config.cache_config.block_size,
                zmq_endpoint=zmq_endpoint,
                zmq_topic="",
                enable_local_indexer=config.enable_local_indexer,
                dp_rank=dp_rank,
            )
        )

        logger.info(
            f"Worker reading KV events for dp_rank={dp_rank} from {zmq_endpoint}"
        )

    return kv_publishers or None


def setup_vllm_engine(config, stat_logger=None):
    """Create the vLLM ``AsyncLLM`` engine together with its Prometheus plumbing.

    Side effects:
        * Sets ``PROMETHEUS_MULTIPROC_DIR`` to a new temporary directory when it
          is not already set, then calls vLLM's ``setup_multiprocess_prometheus()``.
        * Sets ``VLLM_NO_USAGE_STATS=1`` and ``VLLM_WORKER_MULTIPROC_METHOD=spawn``.
          With LoRA enabled, also defaults ``VLLM_ALLOW_RUNTIME_LORA_UPDATING``
          and ``VLLM_LORA_MODULES_LOADING_TIMEOUT`` when they are unset.
        * Sets ``engine_args.worker_cls`` to the GMS worker when
          ``load_format == "gms"``.
        * Assigns ``component_gauges`` onto ``stat_logger`` (when given) and
          ``consolidator_endpoints`` onto the created vLLM config.

    Args:
        config: Worker configuration.
        stat_logger: Optional stat-logger factory (e.g. ``StatLoggerFactory``)
            passed to vLLM as its only stat logger.

    Returns:
        Tuple ``(engine_client, vllm_config, default_sampling_params,
        prometheus_temp_dir, component_gauges)``. ``prometheus_temp_dir`` is
        the ``TemporaryDirectory`` created here, or None when
        ``PROMETHEUS_MULTIPROC_DIR`` was already set.
        NOTE(review): callers presumably keep that object alive (e.g. via
        ``handler.add_temp_dir``) so the directory is not removed early.
    """
    # vLLM v0.11.0 bug: vllm/v1/metrics/prometheus.py:79 passes TemporaryDirectory object
    # instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR
    # ourselves to avoid this and handle cleanup properly.
    prometheus_temp_dir = None
    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
        prometheus_temp_dir = tempfile.TemporaryDirectory(prefix="vllm_prometheus_")
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_temp_dir.name
        logger.debug(
            f"Created PROMETHEUS_MULTIPROC_DIR at: {os.environ['PROMETHEUS_MULTIPROC_DIR']}"
        )

    setup_multiprocess_prometheus()  # call vLLM's library's function to setup multiprocess prometheus
    logger.debug(
        f"Prometheus multiproc dir set to: {os.environ.get('PROMETHEUS_MULTIPROC_DIR')}"
    )

    # Construct Prometheus gauges AFTER setup_multiprocess_prometheus() so Gauge objects
    # see the correct ValueClass (multiprocess vs in-memory).
    component_gauges = LLMBackendMetrics(
        registry=DYNAMO_COMPONENT_REGISTRY,
        model_name=config.served_model_name or "",
        component_name=config.component or "",
    )

    # If a StatLoggerFactory was provided, give it the gauges so the loggers
    # it creates can publish Prometheus metrics.
    if stat_logger is not None:
        stat_logger.component_gauges = component_gauges

    os.environ["VLLM_NO_USAGE_STATS"] = "1"  # Avoid internal HTTP requests
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    engine_args = config.engine_args

    if engine_args.enable_lora:
        # Only provide defaults; values already exported by the operator win.
        os.environ.setdefault("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "True")
        os.environ.setdefault("VLLM_LORA_MODULES_LOADING_TIMEOUT", "600")

    if engine_args.load_format == "gms":
        engine_args.worker_cls = "gpu_memory_service.integrations.vllm.worker.GMSWorker"

    # Load default sampling params from `generation_config.json`
    default_sampling_params = (
        engine_args.create_model_config().get_diff_sampling_param()
    )

    # Taken from build_async_engine_client_from_engine_args()
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    # Set up consolidator endpoints if KVBM is enabled
    consolidator_endpoints = None
    if config.has_connector("kvbm"):
        try:
            from kvbm.vllm_integration.consolidator_config import (
                get_consolidator_endpoints,
            )

            consolidator_endpoints = get_consolidator_endpoints(vllm_config)
        except Exception as e:
            # Deliberately best-effort: the worker runs without consolidation.
            logger.warning(
                f"KVBM connector is enabled but failed to get consolidator endpoints: {e}. "
                "Continuing without KV event consolidation. "
                "Ensure 'kvbm' package is installed if this feature is needed."
            )
    vllm_config.consolidator_endpoints = consolidator_endpoints

    stat_loggers = [stat_logger] if stat_logger else []

    # Time engine initialization with a monotonic clock so wall-clock
    # adjustments (NTP, manual changes) cannot skew the reported load time.
    start_time = time.perf_counter()
    engine_client = AsyncLLM.from_vllm_config(
        vllm_config=vllm_config,
        usage_context=usage_context,
        stat_loggers=stat_loggers,
        enable_log_requests=engine_args.enable_log_requests,
        disable_log_stats=engine_args.disable_log_stats,
    )
    load_time = time.perf_counter() - start_time

    # Record model load time
    component_gauges.set_model_load_time(load_time)

    logger.info(f"VllmWorker for {config.served_model_name} has been initialized")

    return (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
        component_gauges,
    )


async def register_vllm_model(
    model_input: ModelInput,
    model_type: ModelType,
    generate_endpoint,
    config: Config,
    engine_client: AsyncLLM,
    vllm_config,
):
    """Register a vLLM model with Dynamo, including its runtime configuration.

    Reads block counts and scheduler limits from the engine via
    ``get_engine_cache_info``, then fills a ``ModelRuntimeConfig`` with them. The
    config also gets the local-indexer flag, tool/reasoning parsers (non-prefill
    models only) and the data-parallel size. Optional frontend media decoding is
    configured, and the result is passed to ``register_llm``.

    Args:
        model_input: Input type for the model (e.g., ModelInput.Tokens).
        model_type: Type of model (e.g., ModelType.Chat, ModelType.Prefill).
        generate_endpoint: Endpoint to register.
        config: Configuration object.
        engine_client: vLLM engine client.
        vllm_config: vLLM configuration.

    Raises:
        RuntimeError: If ``config.frontend_decoding`` is set but MediaDecoder /
            MediaFetcher could not be imported from ``dynamo.llm``.
    """
    runtime_config = ModelRuntimeConfig()

    # Get runtime configuration from vLLM engine
    logger.info(
        f"Getting engine runtime configuration metadata from vLLM engine for {model_type}..."
    )
    runtime_values = get_engine_cache_info(engine_client)
    runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
    runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
    runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]

    # Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer
    runtime_config.enable_local_indexer = (
        config.enable_local_indexer and not config.is_decode_worker
    )

    # Add tool/reasoning parsers for decode models
    if model_type != ModelType.Prefill:
        runtime_config.tool_call_parser = config.dyn_tool_call_parser
        runtime_config.reasoning_parser = config.dyn_reasoning_parser

    # Get data_parallel_size from vllm_config (defaults to 1)
    runtime_config.data_parallel_size = getattr(
        vllm_config.parallel_config, "data_parallel_size", 1
    )

    # Configure media decoder for frontend image decoding when enabled
    # This enables frontend to decode images and transfer via NIXL RDMA
    media_decoder = None
    media_fetcher = None
    if config.frontend_decoding:
        if not MEDIA_DECODER_AVAILABLE:
            raise RuntimeError(
                "--frontend-decoding requires MediaDecoder support. "
                "Ensure dynamo.llm module includes MediaDecoder and MediaFetcher."
            )
        # Passed to the image decoder as its "max_alloc" limit (128 MiB).
        image_max_alloc = 128 * 1024 * 1024
        fetch_timeout_ms = 30000

        media_decoder = MediaDecoder()
        media_decoder.enable_image({"limits": {"max_alloc": image_max_alloc}})
        # Video decoding is not enabled yet (would be media_decoder.enable_video).

        media_fetcher = MediaFetcher()
        media_fetcher.timeout_ms(fetch_timeout_ms)
        media_fetcher.allow_direct_port(True)

    await register_llm(
        model_input,
        model_type,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=runtime_values["block_size"],
        runtime_config=runtime_config,
        custom_template_path=config.custom_jinja_template,
        media_decoder=media_decoder,
        media_fetcher=media_fetcher,
    )


async def init_prefill(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    pre_created_engine=None,
):
    """Instantiate a prefill worker and serve its endpoints.

    Reuses ``pre_created_engine`` (the tuple returned by ``setup_vllm_engine``,
    supplied in checkpoint mode) or creates a new engine. It then wires up KV
    event publishers, metrics collection and the sleep/wake_up engine routes. On
    the leader data-parallel rank only, it registers the model as
    ``ModelType.Prefill`` and serves the configured generate endpoint and
    ``clear_kv_blocks`` until they return.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")

    # In practice config.served_model_name is always set, but mypy needs the "or"
    # here. Used for BOTH endpoints so their metric labels always agree.
    model_name = config.served_model_name or config.model

    # Use pre-created engine if provided (checkpoint mode), otherwise create new
    engine_setup = (
        pre_created_engine
        if pre_created_engine is not None
        else setup_vllm_engine(config)
    )
    (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
        _component_gauges,
    ) = engine_setup

    handler = PrefillWorkerHandler(
        runtime,
        component,
        engine_client,
        default_sampling_params,
        getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
        shutdown_event=shutdown_event,
        enable_frontend_decoding=config.frontend_decoding,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
    consolidator_enabled = False
    consolidator_port = None
    consolidator_endpoints = getattr(vllm_config, "consolidator_endpoints", None)
    if consolidator_endpoints:
        # Extract connect endpoint (third element) for clients to subscribe
        # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint)
        consolidator_output_endpoint = consolidator_endpoints[2]
        consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
        consolidator_enabled = True

    # Set up KV event publishers for prefix caching if enabled (one per dp_rank)
    # If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
    kv_publishers = setup_kv_event_publisher(
        config,
        component,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    setup_metrics_collection(config, generate_endpoint, logger)

    # Register sleep/wake_up engine routes
    runtime.register_engine_route("sleep", handler.sleep)
    runtime.register_engine_route("wake_up", handler.wake_up)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # Handle non-leader nodes - don't serve endpoints
    if config.engine_args.data_parallel_rank:
        await _handle_non_leader_node(config.engine_args.data_parallel_rank)
        return

    # Register prefill model with ModelType.Prefill
    model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
    await register_vllm_model(
        model_input,
        ModelType.Prefill,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
    )

    health_check_payload = VllmPrefillHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    ).to_dict()

    try:
        logger.debug("Starting serve_endpoint for prefill worker")
        await asyncio.gather(
            # for prefill, we want to shutdown the engine after all prefill requests are finished because
            #     (temp reason): we don't support re-routing prefill requests
            #     (long-term reason): prefill engine should pull from a global queue so there is
            #                         only a few in-flight requests that can be quickly finished
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=[
                    (prometheus_names.labels.MODEL, model_name),
                    (prometheus_names.labels.MODEL_NAME, model_name),
                ],
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=[
                    (prometheus_names.labels.MODEL, model_name),
                    (prometheus_names.labels.MODEL_NAME, model_name),
                ],
            ),
        )
        logger.debug("serve_endpoint completed for prefill worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up prefill worker")
        handler.cleanup()


675
async def init(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    pre_created_engine=None,
):
    """
    Instantiate and serve the aggregated/decode vLLM worker.

    Creates the vLLM engine (or reuses one built ahead of time in checkpoint
    mode), wires up stat logging, KV event publishing and metrics collection,
    registers the ``sleep``/``wake_up`` engine routes and the model, then serves
    the generate, clear_kv_blocks and LoRA management endpoints.

    Args:
        runtime: Distributed runtime used to create the component and endpoints.
        config: Parsed worker configuration.
        shutdown_event: Forwarded to ``DecodeWorkerHandler``.
        pre_created_engine: Optional 5-tuple ``(engine_client, vllm_config,
            default_sampling_params, prometheus_temp_dir, component_gauges)``
            (checkpoint mode). When ``None`` a new engine is created with
            ``setup_vllm_engine``.

    Non-leader data-parallel ranks return after route/metrics setup without
    registering the model or serving endpoints.

    Raises:
        Exception: Any error raised while serving is logged and re-raised; the
            handler is cleaned up once serving has been attempted.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")
    load_lora_endpoint = component.endpoint("load_lora")
    unload_lora_endpoint = component.endpoint("unload_lora")
    list_loras_endpoint = component.endpoint("list_loras")

    model_name = config.served_model_name or config.model

    # Use pre-created engine if provided (checkpoint mode), otherwise create new
    if pre_created_engine is not None:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            component_gauges,
        ) = pre_created_engine
        # Factory is created after unpack so component_gauges is available
        factory = StatLoggerFactory(
            component,
            component_gauges=component_gauges,
            dp_rank=config.engine_args.data_parallel_rank or 0,
            metrics_labels=[("model", model_name)],
        )
    else:
        # Factory is created without component_gauges; setup_vllm_engine() will
        # create the gauges after setup_multiprocess_prometheus() and set them
        # on the factory before vLLM calls create_stat_logger().
        factory = StatLoggerFactory(
            component,
            dp_rank=config.engine_args.data_parallel_rank or 0,
            metrics_labels=[("model", model_name)],
        )
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            component_gauges,
        ) = setup_vllm_engine(config, factory)

    # TODO Hack to get data, move this to registering in TBD
    factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks)
    factory.init_publish()

    handler = DecodeWorkerHandler(
        runtime,
        component,
        engine_client,
        default_sampling_params,
        getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
        shutdown_event=shutdown_event,
        enable_frontend_decoding=config.frontend_decoding,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
    consolidator_enabled = False
    consolidator_port = None
    consolidator_endpoints = getattr(vllm_config, "consolidator_endpoints", None)
    if consolidator_endpoints:
        # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint);
        # clients subscribe to the connect endpoint (third element).
        consolidator_output_endpoint = consolidator_endpoints[2]
        consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
        consolidator_enabled = True

    # Set up KV event publisher for prefix caching if enabled.
    # If the kv event consolidator is enabled, the publisher subscribes to the
    # consolidator's output instead of vLLM directly.
    kv_publishers = setup_kv_event_publisher(
        config,
        component,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    setup_metrics_collection(config, generate_endpoint, logger)

    # Register sleep/wake_up engine routes
    runtime.register_engine_route("sleep", handler.sleep)
    runtime.register_engine_route("wake_up", handler.wake_up)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # Handle non-leader nodes - don't serve endpoints
    if config.engine_args.data_parallel_rank:
        await _handle_non_leader_node(config.engine_args.data_parallel_rank)
        return

    # Parse endpoint types from --endpoint-types flag
    model_type = parse_endpoint_types(config.endpoint_types)
    logger.info(f"Registering model with endpoint types: {config.endpoint_types}")

    model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens

    # Warn if custom template provided but chat endpoint not enabled
    if config.custom_jinja_template and "chat" not in config.endpoint_types:
        logger.warning(
            "Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
            "The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
        )

    await register_vllm_model(
        model_input,
        model_type,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
    )

    health_check_payload = VllmHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    ).to_dict()

    try:
        logger.debug("Starting serve_endpoint for decode worker")
        # Every endpoint reports the same model labels, so build the list once.
        # The config is re-read here (rather than reusing ``model_name``) so the
        # value is evaluated at the same point as before serving starts.
        served_name = config.served_model_name or config.model
        metrics_labels = [
            (prometheus_names.labels.MODEL, served_name),
            (prometheus_names.labels.MODEL_NAME, served_name),
        ]
        await asyncio.gather(
            # NOTE(review): the intent stated for decode was to transfer in-flight
            # requests to other decode engines (waiting can take long for long
            # OSLs), yet graceful_shutdown=True is passed — confirm intended mode.
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=metrics_labels,
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=metrics_labels,
            ),
            load_lora_endpoint.serve_endpoint(
                handler.load_lora,
                metrics_labels=metrics_labels,
            ),
            unload_lora_endpoint.serve_endpoint(
                handler.unload_lora,
                metrics_labels=metrics_labels,
            ),
            list_loras_endpoint.serve_endpoint(
                handler.list_loras,
                metrics_labels=metrics_labels,
            ),
        )
        logger.debug("serve_endpoint completed for decode worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up decode worker")
        # Cleanup background tasks
        handler.cleanup()


895
896
897
898
899
900
901
def get_engine_cache_info(engine: AsyncLLM):
    """Return KV-cache and scheduler limits read from an ``AsyncLLM`` engine.

    Values are read directly from ``engine.vllm_config`` rather than through a
    collective RPC.

    Args:
        engine: The vLLM ``AsyncLLM`` engine to inspect.

    Returns:
        A dict with ``num_gpu_blocks`` and ``block_size`` (from
        ``cache_config``) followed by ``max_num_seqs`` and
        ``max_num_batched_tokens`` (from ``scheduler_config``).

    Raises:
        Exception: Any error while reading the config is logged and re-raised.
    """
    try:
        cache_values = {
            "num_gpu_blocks": engine.vllm_config.cache_config.num_gpu_blocks,
            "block_size": engine.vllm_config.cache_config.block_size,
        }

        scheduler_values = {
            "max_num_seqs": engine.vllm_config.scheduler_config.max_num_seqs,
            "max_num_batched_tokens": engine.vllm_config.scheduler_config.max_num_batched_tokens,
        }

        # Use the module logger like the rest of this file; the root-level
        # ``logging.info`` helpers bypass it and implicitly call basicConfig()
        # when the root logger has no handlers.
        logger.info(f"Cache config values: {cache_values}")
        logger.info(f"Scheduler config values: {scheduler_values}")
        return {**cache_values, **scheduler_values}
    except Exception as e:
        logger.error(f"Failed to get configuration values from vLLM config: {e}")
        raise


923
924
925
async def init_multimodal_processor(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """Initialize and serve the multimodal processor component.

    Builds clients for the ``encoder`` and ``backend`` (PD worker) ``generate``
    endpoints in ``config.namespace``, waits for encoder instances, registers
    this component's endpoint as a chat model taking token input, and serves it.

    Args:
        runtime: Distributed runtime used to resolve components and endpoints.
        config: Parsed worker configuration.
        shutdown_event: Accepted for signature parity with the other ``init_*``
            entry points; not used by this function.

    Raises:
        Exception: Errors from serving are logged and re-raised; the handler is
            cleaned up either way.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)

    # Get encode worker client
    encode_worker_client = (
        await runtime.namespace(config.namespace)
        .component("encoder")
        .endpoint("generate")
        .client()
    )

    # Get PD worker client
    pd_worker_client = (
        await runtime.namespace(config.namespace)
        .component("backend")
        .endpoint("generate")
        .client()
    )

    handler = PreprocessedHandler(
        config.engine_args,
        encode_worker_client,
        pd_worker_client,
    )

    logger.info("Waiting for Encoder Worker Instances ...")
    # NOTE(review): only encoder instances are awaited here, whereas
    # init_ec_processor also waits for PD worker instances — confirm intent.
    await encode_worker_client.wait_for_instances()

    # Register the endpoint as entrypoint to a model
    await register_llm(
        ModelInput.Tokens,
        ModelType.Chat,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info("Starting to serve the processor endpoint...")

    try:
        # Label with the served name (falling back to the model path) so this
        # component reports the same model label as the decode and Omni workers.
        model_label = config.served_model_name or config.model
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                metrics_labels=[
                    (prometheus_names.labels.MODEL, model_label),
                    (prometheus_names.labels.MODEL_NAME, model_label),
                ],
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()


984
985
986
async def init_multimodal_encode_worker(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """Initialize and serve the multimodal encode worker component.

    Creates a client for the ``backend`` (PD worker) ``generate`` endpoint,
    builds and asynchronously initializes an ``EncodeWorkerHandler``, waits for
    PD worker instances, then serves this component's generate endpoint.

    Args:
        runtime: Distributed runtime used to resolve components and endpoints.
        config: Parsed worker configuration.
        shutdown_event: Accepted for signature parity with the other ``init_*``
            entry points; not used by this function.

    Raises:
        Exception: Errors from serving are logged and re-raised; the handler is
            cleaned up either way.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)

    # Get PD worker client
    # In multimodal mode, the PD worker always registers as "backend"
    # (even in disaggregated mode with prefill/decode split, we still connect to "backend")
    pd_worker_client = (
        await runtime.namespace(config.namespace)
        .component("backend")
        .endpoint("generate")
        .client()
    )

    handler = EncodeWorkerHandler(
        config.engine_args,
        pd_worker_client,
    )
    await handler.async_init(runtime)
    logger.info("Waiting for PD Worker Instances ...")
    await pd_worker_client.wait_for_instances()
    logger.info("Starting to serve the encode worker endpoint...")

    try:
        # Label with the served name (falling back to the model path) so this
        # component reports the same model label as the decode and Omni workers.
        model_label = config.served_model_name or config.model
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                metrics_labels=[
                    (prometheus_names.labels.MODEL, model_label),
                    (prometheus_names.labels.MODEL_NAME, model_label),
                ],
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve encode worker endpoint: {e}")
        raise
    finally:
        handler.cleanup()


1028
1029
1030
async def init_vllm_native_encoder(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """
    Initialize vLLM-native encoder worker component (ECConnector mode).
    In this mode, vLLM handles encoder execution, caching, and storage automatically.

    The engine is configured as an ECConnector ``ec_producer`` (it creates
    embeddings and stores them via the configured connector backend), then a
    ``VLLMEncodeWorkerHandler`` is served on this component's endpoint.

    Args:
        runtime: Distributed runtime used to resolve the component/endpoint.
        config: Parsed worker configuration; ``config.engine_args`` is mutated
            to carry the ECTransferConfig.
        shutdown_event: Accepted for signature parity with the other ``init_*``
            entry points; not used by this function.

    Raises:
        Exception: Errors from serving are logged and re-raised; the handler is
            cleaned up either way.
    """
    # Create component and endpoint
    component = runtime.namespace(config.namespace).component(config.component)
    generate_endpoint = component.endpoint(config.endpoint)

    # Configure ECTransferConfig for producer role.
    # NOTE(review): instance_id is hard-coded to 0, so every encoder replica
    # gets the same engine_id — confirm this is acceptable for the connector.
    instance_id = 0
    engine_id = f"{config.namespace}.{config.component}.encoder.{instance_id}"

    # Configure encoder with producer role, it will be responsible for creating embeddings and storing them in the shared storage
    ec_transfer_config = create_ec_transfer_config(
        engine_id=engine_id,
        ec_role="ec_producer",
        ec_connector_backend=config.ec_connector_backend,
        ec_storage_path=config.ec_storage_path,
        ec_extra_config=config.ec_extra_config,
    )

    # Set ECTransferConfig on engine args
    config.engine_args.ec_transfer_config = ec_transfer_config

    # Setup vLLM engine
    (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
        _component_gauges,
    ) = setup_vllm_engine(config)

    # Initialize vLLM Native Encoder Worker Handler
    handler = VLLMEncodeWorkerHandler(
        runtime,
        component,
        engine_client,
        config,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # No handler.async_init(runtime) here: in ECConnector mode vLLM handles it.

    logger.info("Starting to serve vLLM-native encoder endpoint...")

    # Serve endpoint
    try:
        # Label with the served name (falling back to the model path) so this
        # component reports the same model label as the decode and Omni workers.
        model_label = config.served_model_name or config.model
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                metrics_labels=[
                    (prometheus_names.labels.MODEL, model_label),
                    (prometheus_names.labels.MODEL_NAME, model_label),
                ],
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve vLLM-native encoder endpoint: {e}")
        raise
    finally:
        handler.cleanup()


1096
1097
1098
async def init_ec_processor(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """
    Initialize ECConnector processor component.

    Simple processor that routes multimodal requests using ECConnector pattern:
    1. Preprocess request (same as regular processor)
    2. Send multimodal items to encoder workers (stores to shared storage)
    3. Forward preprocessed request to PD worker (loads from shared storage)
    4. Stream response back to client

    Args:
        runtime: Distributed runtime used to resolve components and endpoints.
        config: Parsed worker configuration.
        shutdown_event: Accepted for signature parity with the other ``init_*``
            entry points; not used by this function.

    Raises:
        Exception: Errors from serving are logged and re-raised; the handler is
            cleaned up either way.
    """
    # Create component and endpoint
    component = runtime.namespace(config.namespace).component(config.component)
    generate_endpoint = component.endpoint(config.endpoint)

    # Get encoder worker client
    encoder_client = (
        await runtime.namespace(config.namespace)
        .component("encoder")
        .endpoint("generate")
        .client()
    )

    # Get PD worker client
    pd_client = (
        await runtime.namespace(config.namespace)
        .component("backend")
        .endpoint("generate")
        .client()
    )

    # Prompt template comes from the parsed config (--mm-prompt-template)
    mm_prompt_template = config.mm_prompt_template

    # Create EC processor handler (with preprocessing like regular processor)
    handler = ECProcessorHandler(
        config.engine_args,
        encoder_worker_client=encoder_client,
        pd_worker_client=pd_client,
        prompt_template=mm_prompt_template,
    )

    logger.info("Waiting for encoder and PD worker instances...")
    await encoder_client.wait_for_instances()
    await pd_client.wait_for_instances()

    # Register the endpoint as entrypoint to a model (same as preprocessed_handler)
    await register_llm(
        ModelInput.Tokens,  # Use Rust tokenization for better performance and multi-image support
        ModelType.Chat,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info("Starting to serve EC processor endpoint...")

    try:
        # Label with the served name (falling back to the model path) so this
        # component reports the same model label as the decode and Omni workers.
        model_label = config.served_model_name or config.model
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                metrics_labels=[
                    (prometheus_names.labels.MODEL, model_label),
                    (prometheus_names.labels.MODEL_NAME, model_label),
                ],
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve EC processor endpoint: {e}")
        raise
    finally:
        handler.cleanup()


1172
async def init_multimodal_worker(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    pre_created_engine=None,
):
    """
    Initialize multimodal worker component.

    Supports two modes:
    1. --multimodal-worker: Receives embeddings from separate encoder
    2. --multimodal-encode-prefill-worker: Handles inline encoding (e.g., Llama 4)

    Both can operate in aggregated (P+D) or disaggregated (P→D) mode.

    When --ec-consumer-mode is enabled, configures as ECConnector consumer
    to load encoder embeddings from shared storage.

    Args:
        runtime: Distributed runtime used to resolve components and endpoints.
        config: Parsed worker configuration; ``config.engine_args`` may be
            mutated to carry an ECTransferConfig.
        shutdown_event: Forwarded to the selected handler.
        pre_created_engine: Optional 5-tuple from ``setup_vllm_engine``
            (checkpoint mode). When ``None`` a new engine is created.

    Raises:
        Exception: Errors from serving are logged and re-raised; the handler is
            cleaned up either way.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")

    # Configure ECConnector consumer mode if enabled
    if config.ec_consumer_mode:
        logger.info("Configuring as ECConnector consumer for encoder embeddings")
        instance_id = 0
        engine_id = f"{config.namespace}.{config.component}.backend.{instance_id}"

        # The PD Worker just load the embeddings from the shared storage, so it is a consumer
        ec_transfer_config = create_ec_transfer_config(
            engine_id=engine_id,
            ec_role="ec_consumer",
            ec_connector_backend=config.ec_connector_backend,
            ec_storage_path=config.ec_storage_path,
            ec_extra_config=config.ec_extra_config,
        )

        # Set ECTransferConfig on engine args
        config.engine_args.ec_transfer_config = ec_transfer_config
        logger.info(f"Configured as ECConnector consumer with engine_id={engine_id}")

    # Use pre-created engine if provided (checkpoint mode), otherwise create new
    if pre_created_engine is not None:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            _component_gauges,
        ) = pre_created_engine
    else:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            _component_gauges,
        ) = setup_vllm_engine(config)

    # Set up decode worker client for disaggregated mode
    decode_worker_client = None
    if config.is_prefill_worker:
        # Prefill worker needs to connect to decode worker
        decode_worker_client = (
            await runtime.namespace(config.namespace)
            .component("decoder")
            .endpoint("generate")
            .client()
        )
        await decode_worker_client.wait_for_instances()
        logger.info("Connected to decode worker for disaggregated mode")

    # Choose handler based on worker type
    if config.multimodal_decode_worker:
        handler = MultimodalDecodeWorkerHandler(
            runtime, component, engine_client, config, shutdown_event
        )
    else:
        handler = MultimodalPDWorkerHandler(
            runtime,
            component,
            engine_client,
            config,
            decode_worker_client,
            shutdown_event,
        )
    handler.add_temp_dir(prometheus_temp_dir)

    await handler.async_init(runtime)

    # Set up KV event publisher for prefix caching if enabled
    kv_publisher = setup_kv_event_publisher(
        config, component, generate_endpoint, vllm_config
    )
    if kv_publisher:
        # NOTE(review): `init` stores this on `handler.kv_publishers` (plural);
        # confirm the multimodal handlers read the singular attribute.
        handler.kv_publisher = kv_publisher

    # Label with the served name (falling back to the model path) so this
    # component reports the same model label as the decode and Omni workers.
    model_label = config.served_model_name or config.model
    metrics_labels = [
        (prometheus_names.labels.MODEL, model_label),
        (prometheus_names.labels.MODEL_NAME, model_label),
    ]

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                metrics_labels=metrics_labels,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=metrics_labels,
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()


1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
async def init_omni(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """
    Serve a generation worker backed by the vLLM-Omni orchestrator.

    Builds an ``OmniHandler`` (imported lazily so vllm-omni is only loaded when
    this mode is selected) and sets up metrics collection. On the data-parallel
    leader only, it registers the model (``ModelInput.Text`` /
    ``ModelType.Images``) and serves the generate endpoint with graceful
    shutdown and a health-check payload. Only a single-stage pipeline is
    handled for now.

    NOTE(review): the original description says text-to-text, yet the model is
    registered as ``ModelType.Images`` — confirm which is intended.
    """
    # Deferred import: vllm-omni is only needed for this worker mode.
    from dynamo.vllm.omni import OmniHandler

    namespace = runtime.namespace(config.namespace)
    component = namespace.component(config.component)
    omni_endpoint = component.endpoint(config.endpoint)

    # Handler wraps the Omni orchestrator.
    handler = OmniHandler(
        runtime=runtime,
        component=component,
        config=config,
        default_sampling_params={},
        shutdown_event=shutdown_event,
    )
    logger.info(f"Omni worker initialized for model: {config.model}")

    # vLLM / LMCache metrics collection.
    setup_metrics_collection(config, omni_endpoint, logger)

    # Only the data-parallel leader (rank 0 or unset) serves endpoints.
    dp_rank = config.engine_args.data_parallel_rank
    if dp_rank:
        await _handle_non_leader_node(dp_rank)
        return

    # TODO: extend for multi-stage pipelines
    await register_llm(
        ModelInput.Text,
        ModelType.Images,
        omni_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )
    logger.info("Starting to serve Omni worker endpoint...")

    payload = await VllmOmniHealthCheckPayload.create(handler.engine_client)
    health_check_payload = payload.to_dict()

    try:
        label_value = config.served_model_name or config.model
        labels = [
            (prometheus_names.labels.MODEL, label_value),
            (prometheus_names.labels.MODEL_NAME, label_value),
        ]
        await omni_endpoint.serve_endpoint(
            handler.generate,
            graceful_shutdown=True,
            metrics_labels=labels,
            health_check_payload=health_check_payload,
        )
    except Exception as e:
        logger.error(f"Failed to serve Omni endpoint: {e}")
        raise
    finally:
        logger.debug("Cleaning up Omni worker")
        handler.cleanup()


Alec's avatar
Alec committed
1366
1367
1368
1369
def main():
    """Console entry point: drive the async ``worker()`` coroutine on uvloop."""
    worker_coro = worker()
    uvloop.run(worker_coro)


Alec's avatar
Alec committed
1370
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()