main.py 37.4 KB
Newer Older
1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Alec's avatar
Alec committed
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0

import asyncio
import logging
import os
import signal
8
import tempfile
9
from typing import Optional
Alec's avatar
Alec committed
10
11

import uvloop
12
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
Alec's avatar
Alec committed
13
14
15
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
16
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
Alec's avatar
Alec committed
17

18
from dynamo.common.config_dump import dump_config
19
from dynamo.common.utils.endpoint_types import parse_endpoint_types
20
from dynamo.common.utils.prometheus import register_engine_metrics_callback
Alec's avatar
Alec committed
21
from dynamo.llm import (
22
    ModelInput,
23
    ModelRuntimeConfig,
Alec's avatar
Alec committed
24
25
26
    ModelType,
    ZmqKvEventPublisher,
    ZmqKvEventPublisherConfig,
27
    fetch_llm,
Alec's avatar
Alec committed
28
29
    register_llm,
)
30
from dynamo.runtime import DistributedRuntime
Alec's avatar
Alec committed
31
from dynamo.runtime.logging import configure_dynamo_logging
32
from dynamo.vllm.multimodal_handlers import (
33
    ECProcessorHandler,
34
    EncodeWorkerHandler,
Ayush Agarwal's avatar
Ayush Agarwal committed
35
    MultimodalDecodeWorkerHandler,
36
    MultimodalPDWorkerHandler,
GuanLuo's avatar
GuanLuo committed
37
    PreprocessedHandler,
38
    VLLMEncodeWorkerHandler,
39
)
40
from dynamo.vllm.multimodal_utils.encode_utils import create_ec_transfer_config
Alec's avatar
Alec committed
41

42
from .args import Config, overwrite_args, parse_args
Alec's avatar
Alec committed
43
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
44
from .health_check import VllmHealthCheckPayload, VllmPrefillHealthCheckPayload
Alec's avatar
Alec committed
45
46
from .publisher import StatLoggerFactory

Alec's avatar
Alec committed
47
48
49
50
# Configure Dynamo logging at import time, before the module-level logger
# below is created and used by every function in this file.
configure_dynamo_logging()
logger = logging.getLogger(__name__)


51
52
53
54
55
56
57
58
59
60
61
62
63
async def _handle_non_leader_node(dp_rank: int) -> None:
    """
    Park a non-leader data-parallel node instead of serving endpoints.

    In multi-node deployments a node with data_parallel_rank >= 1 still runs
    vLLM workers but does not expose any Dynamo endpoints. This coroutine
    logs that decision and then blocks forever; the process is expected to be
    terminated through the installed signal handlers.

    Args:
        dp_rank: The node's data-parallel rank (only used for logging).
    """
    message = (
        f"Non-leader node detected (data_parallel_rank={dp_rank}). "
        "Skipping endpoint serving."
    )
    logger.info(message)

    # An Event that nobody ever sets: awaiting it never returns on its own.
    never_set = asyncio.Event()
    await never_set.wait()


Alec's avatar
Alec committed
64
65
async def graceful_shutdown(runtime):
    """
    Shut down the Dynamo distributed runtime in response to a signal.

    The endpoints are invalidated immediately, so no new requests are accepted.
    For endpoints served with graceful_shutdown=True, the serving function waits
    until all in-flight requests are finished. For endpoints served with
    graceful_shutdown=False, the serving function returns immediately.

    Args:
        runtime: The DistributedRuntime to shut down.
    """
    # Use the module logger (like the rest of this file) instead of the root logger.
    logger.info("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()
    logger.info("DistributedRuntime shutdown complete")


76
async def worker():
    """
    Entry coroutine for a Dynamo vLLM worker process.

    Parses the CLI configuration, creates the DistributedRuntime, installs
    SIGTERM/SIGINT handlers that schedule :func:`graceful_shutdown`, calls
    ``dump_config`` with ``config.dump_config_to``, makes sure the model is
    available locally, and then dispatches to the initializer matching the
    configured worker role.
    """
    config = parse_args()

    loop = asyncio.get_running_loop()
    overwrite_args(config)

    # Enable NATS based on use_kv_events flag (derived from kv_events_config)
    runtime = DistributedRuntime(
        loop, config.store_kv, config.request_plane, config.use_kv_events
    )

    # The event loop keeps only weak references to tasks, so hold a strong
    # reference to each shutdown task until it finishes; otherwise it could be
    # garbage-collected before runtime.shutdown() runs.
    shutdown_tasks: set[asyncio.Task] = set()

    # Set up signal handler for graceful shutdown
    def signal_handler():
        task = asyncio.create_task(graceful_shutdown(runtime))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    logger.debug("Signal handlers set up for graceful shutdown")

    dump_config(config.dump_config_to, config)

    # Name the model. Use either the full path (vllm and sglang do the same),
    # or the HF name (e.g. "Qwen/Qwen3-0.6B"), depending on cmd line params.
    if not config.served_model_name:
        config.served_model_name = config.engine_args.served_model_name = config.model

    # Download the model if necessary using modelexpress.
    # We want it on disk before we start vllm to avoid downloading from HuggingFace.
    #
    # We don't set `config.engine_args.model` to the local path fetch_llm returns
    # because vllm will send that name to its Ray pipeline-parallel workers, which
    # may not have the local path.
    # vllm will attempt to download the model again, but find it in the HF cache.
    # For non-HF models use a path instead of an HF name, and ensure all workers have
    # that path (ideally via a shared folder).
    if not os.path.exists(config.model):
        await fetch_llm(config.model)

    # Route to appropriate initialization based on config flags
    if config.vllm_native_encoder_worker:
        await init_vllm_native_encoder(runtime, config)
        logger.debug("init_vllm_native_encoder completed")
    elif config.ec_processor:
        await init_ec_processor(runtime, config)
        logger.debug("init_ec_processor completed")
    elif config.multimodal_processor:
        await init_multimodal_processor(runtime, config)
        logger.debug("init_multimodal_processor completed")
    elif config.multimodal_encode_worker:
        await init_multimodal_encode_worker(runtime, config)
        logger.debug("init_multimodal_encode_worker completed")
    elif (
        config.multimodal_worker
        or config.multimodal_decode_worker
        or config.multimodal_encode_prefill_worker
    ):
        await init_multimodal_worker(runtime, config)
        logger.debug("init_multimodal_worker completed")
    elif config.is_prefill_worker:
        await init_prefill(runtime, config)
        logger.debug("init_prefill completed")
    else:
        await init(runtime, config)
        logger.debug("init completed")

    logger.debug("Worker function completed, exiting...")
Alec's avatar
Alec committed
143
144


145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def setup_metrics_collection(config: Config, generate_endpoint, logger):
    """Set up metrics collection for vLLM and LMCache metrics.

    In multiprocess mode (PROMETHEUS_MULTIPROC_DIR set), metrics are stored:
      1. In-memory: Metric objects in global REGISTRY
      2. On-disk: Metric values in .db files (PROMETHEUS_MULTIPROC_DIR)

    MultiProcessCollector reads from .db files but adding it to REGISTRY can fail
    with "Duplicated timeseries" if PROMETHEUS_MULTIPROC_DIR was set before process
    started (K8s deployments) because metrics are already in REGISTRY.

    Solution: Try adding MultiProcessCollector to REGISTRY. If that fails, use
    separate registry for multiprocess collection and register callbacks to both
    registries to ensure all metrics (vllm, lmcache, dynamo_component) are collected.

    Args:
        config: Worker configuration; nothing is registered unless
            ``config.engine_args.disable_log_stats`` is exactly False.
        generate_endpoint: Endpoint the metrics callbacks are attached to.
        logger: Logger used for debug messages.
    """
    if config.engine_args.disable_log_stats is not False:
        # vLLM stats logging is not enabled: there are no engine metrics to expose.
        return

    if not os.environ.get("PROMETHEUS_MULTIPROC_DIR"):
        # No multiprocess mode
        register_engine_metrics_callback(
            endpoint=generate_endpoint,
            registry=REGISTRY,
            metric_prefix_filters=["vllm:", "lmcache:"],
        )
        return

    try:
        # MultiProcessCollector reads metrics from .db files in PROMETHEUS_MULTIPROC_DIR
        # Adding it to REGISTRY allows collecting both in-memory and .db file metrics.
        # Only this call is guarded: a ValueError from callback registration must not
        # be mistaken for a collector conflict (that would register callbacks twice).
        multiprocess.MultiProcessCollector(REGISTRY)
    except ValueError as e:
        # Conflict: metrics already in REGISTRY, MultiProcessCollector tries to add same metrics from .db files
        # Solution: Use separate registry that ONLY reads from .db files (no in-memory conflicts)
        logger.debug(
            f"Could not add MultiProcessCollector to REGISTRY ({e}), using separate registry"
        )
        multiproc_registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(multiproc_registry)

        # Register both registries to collect all metrics
        # Global REGISTRY has in-memory metrics (vllm, dynamo_component)
        register_engine_metrics_callback(
            endpoint=generate_endpoint,
            registry=REGISTRY,
            metric_prefix_filters=["vllm:", "dynamo_component:"],
        )
        # Multiproc registry has .db file metrics (lmcache, possibly vllm duplicates)
        register_engine_metrics_callback(
            endpoint=generate_endpoint,
            registry=multiproc_registry,
            metric_prefix_filters=["vllm:", "lmcache:"],
        )
    else:
        logger.debug("Added MultiProcessCollector to global REGISTRY")
        register_engine_metrics_callback(
            endpoint=generate_endpoint,
            registry=REGISTRY,
            metric_prefix_filters=["vllm:", "lmcache:"],
        )


Yan Ru Pei's avatar
Yan Ru Pei committed
203
204
205
206
207
def setup_kv_event_publisher(
    config: Config,
    component,
    generate_endpoint,
    vllm_config,
    consolidator_enabled: bool = False,
    consolidator_port: Optional[int] = 5558,
) -> Optional[list[ZmqKvEventPublisher]]:
    """
    Set up KV event publishers for prefix caching if enabled.

    Creates one publisher per dp_rank since each dp_rank publishes to a different port.
    Nothing is created (None is returned) when prefix caching is disabled, for
    decode workers, when no kv_events_config is set, or when
    kv_events_config.enable_kv_cache_events is False.

    Args:
        config: Worker configuration
        component: Component for runtime integration
        generate_endpoint: Endpoint whose connection_id() is used as the worker ID
        vllm_config: vLLM configuration (supplies block size and data_parallel_size)
        consolidator_enabled: If True, subscribe to the KV event consolidator's ZMQ
            endpoint instead of vLLM's own KV event endpoint
        consolidator_port: Port where the KV event consolidator publishes (default: 5558)

    Returns:
        List of ZmqKvEventPublisher instances (one per dp_rank) if KV event
        publishing is enabled, None otherwise.
    """
    if not config.engine_args.enable_prefix_caching:
        return None

    # Skip KV event publishing for decode workers
    if config.is_decode_worker:
        logger.info("Skipping KV event publisher setup for decode worker")
        return None

    kv_events_config = config.engine_args.kv_events_config
    if kv_events_config is None:
        return None

    # Check if kv_cache_events are explicitly disabled
    if not kv_events_config.enable_kv_cache_events:
        logger.info(
            "KV event publishing skipped: enable_kv_cache_events=False in kv_events_config"
        )
        return None

    # Get data_parallel_size to create publishers for all dp_ranks
    data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
    kv_publishers = []

    for dp_rank in range(data_parallel_size):
        if consolidator_enabled:
            # TODO: Use different port for each dp_rank once KVBM supports DP
            zmq_endpoint = f"tcp://127.0.0.1:{consolidator_port}"
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to consolidator at {zmq_endpoint}"
            )
        else:
            # Each dp_rank publishes to a different port; vLLM binds on "*",
            # so subscribe via loopback.
            zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
                kv_events_config.endpoint,
                data_parallel_rank=dp_rank,
            ).replace("*", "127.0.0.1")
            logger.info(
                f"KV event publisher for dp_rank={dp_rank} subscribing to vLLM at {zmq_endpoint}"
            )

        zmq_config = ZmqKvEventPublisherConfig(
            worker_id=generate_endpoint.connection_id(),
            kv_block_size=vllm_config.cache_config.block_size,
            zmq_endpoint=zmq_endpoint,
            enable_local_indexer=config.enable_local_indexer,
        )
        kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
        kv_publishers.append(kv_publisher)

        logger.info(
            f"Worker reading KV events for dp_rank={dp_rank} from {zmq_endpoint}"
        )

    return kv_publishers or None
Yan Ru Pei's avatar
Yan Ru Pei committed
278
279


Alec's avatar
Alec committed
280
def _create_vllm_engine(config, stat_logger):
    """Build the vLLM AsyncLLM engine; PROMETHEUS_MULTIPROC_DIR must already be set.

    Returns:
        Tuple ``(engine_client, vllm_config, default_sampling_params)``.
    """
    setup_multiprocess_prometheus()  # call vLLM's library's function to setup multiprocess prometheus
    logger.debug(
        f"Prometheus multiproc dir set to: {os.environ.get('PROMETHEUS_MULTIPROC_DIR')}"
    )

    os.environ["VLLM_NO_USAGE_STATS"] = "1"  # Avoid internal HTTP requests
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    engine_args = config.engine_args

    if engine_args.enable_lora:
        # Only fill in defaults; values already present in the environment win.
        os.environ.setdefault("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "True")
        os.environ.setdefault("VLLM_LORA_MODULES_LOADING_TIMEOUT", "600")

    if engine_args.load_format == "gms":
        engine_args.worker_cls = "gpu_memory_service.vllm_integration.worker.GMSWorker"

    # Load default sampling params from `generation_config.json`
    default_sampling_params = (
        engine_args.create_model_config().get_diff_sampling_param()
    )

    # Taken from build_async_engine_client_from_engine_args()
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    # Set up consolidator endpoints if KVBM is enabled
    consolidator_endpoints = None
    if config.has_connector("kvbm"):
        try:
            from kvbm.vllm_integration.consolidator_config import (
                get_consolidator_endpoints,
            )

            consolidator_endpoints = get_consolidator_endpoints(vllm_config)
        except Exception as e:
            # Deliberately best-effort: consolidation is optional.
            logger.warning(
                f"KVBM connector is enabled but failed to get consolidator endpoints: {e}. "
                "Continuing without KV event consolidation. "
                "Ensure 'kvbm' package is installed if this feature is needed."
            )
    # Read later by init()/init_prefill() to decide where KV events come from.
    vllm_config.consolidator_endpoints = consolidator_endpoints

    factory = []
    if stat_logger:
        factory.append(stat_logger)

    engine_client = AsyncLLM.from_vllm_config(
        vllm_config=vllm_config,
        usage_context=usage_context,
        stat_loggers=factory,
        enable_log_requests=engine_args.enable_log_requests,
        disable_log_stats=engine_args.disable_log_stats,
    )
    return engine_client, vllm_config, default_sampling_params


def setup_vllm_engine(config, stat_logger=None):
    """
    Create the vLLM AsyncLLM engine for this worker.

    Args:
        config: Worker configuration; ``config.engine_args`` holds the vLLM engine args.
        stat_logger: Optional stat logger (factory) handed to the engine.

    Returns:
        Tuple ``(engine_client, vllm_config, default_sampling_params, prometheus_temp_dir)``.
        ``prometheus_temp_dir`` is the TemporaryDirectory created for
        PROMETHEUS_MULTIPROC_DIR, or None if that variable was already set; the
        caller takes ownership of it.
    """
    # vLLM v0.11.0 bug: vllm/v1/metrics/prometheus.py:79 passes TemporaryDirectory object
    # instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR
    # ourselves to avoid this and handle cleanup properly.
    prometheus_temp_dir = None
    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
        prometheus_temp_dir = tempfile.TemporaryDirectory(prefix="vllm_prometheus_")
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_temp_dir.name
        logger.debug(
            f"Created PROMETHEUS_MULTIPROC_DIR at: {os.environ['PROMETHEUS_MULTIPROC_DIR']}"
        )

    try:
        engine_client, vllm_config, default_sampling_params = _create_vllm_engine(
            config, stat_logger
        )
    except BaseException:
        # Engine construction failed: don't leak the temp dir we created, and
        # don't leave the env var pointing at a directory that no longer exists.
        if prometheus_temp_dir is not None:
            os.environ.pop("PROMETHEUS_MULTIPROC_DIR", None)
            try:
                prometheus_temp_dir.cleanup()
            except OSError as cleanup_err:
                logger.warning(
                    f"Failed to remove {prometheus_temp_dir.name}: {cleanup_err}"
                )
        raise

    logger.info(f"VllmWorker for {config.served_model_name} has been initialized")

    return engine_client, vllm_config, default_sampling_params, prometheus_temp_dir
Alec's avatar
Alec committed
352
353


354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
async def register_vllm_model(
    model_input: ModelInput,
    model_type: ModelType,
    generate_endpoint,
    config: Config,
    engine_client: AsyncLLM,
    vllm_config,
    migration_limit: int,
):
    """
    Helper function to register a vLLM model with runtime configuration.

    Args:
        model_input: Input type for the model (e.g., ModelInput.Tokens)
        model_type: Type of model (e.g., ModelType.Chat, ModelType.Prefill)
        generate_endpoint: Endpoint to register
        config: Configuration object
        engine_client: vLLM engine client
        vllm_config: vLLM configuration
        migration_limit: Migration limit for the model
    """
    runtime_config = ModelRuntimeConfig()

    # Get runtime configuration from vLLM engine
    logger.info(
        f"Getting engine runtime configuration metadata from vLLM engine for {model_type}..."
    )
    runtime_values = get_engine_cache_info(engine_client)
    runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
    runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
    runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
    runtime_config.enable_local_indexer = config.enable_local_indexer

    # Add tool/reasoning parsers for decode models
    if model_type != ModelType.Prefill:
        runtime_config.tool_call_parser = config.tool_call_parser
        runtime_config.reasoning_parser = config.reasoning_parser

    # Get data_parallel_size from vllm_config (defaults to 1)
    data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
    runtime_config.data_parallel_size = data_parallel_size

    # Register the block size vLLM actually resolved (the same value the KV event
    # publishers report). The raw engine arg may be unset when the user did not
    # pass --block-size, so it is only used as a fallback.
    kv_cache_block_size = (
        vllm_config.cache_config.block_size or config.engine_args.block_size
    )

    await register_llm(
        model_input,
        model_type,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=kv_cache_block_size,
        migration_limit=migration_limit,
        runtime_config=runtime_config,
        custom_template_path=config.custom_jinja_template,
    )


Alec's avatar
Alec committed
409
410
411
412
async def init_prefill(runtime: DistributedRuntime, config: Config):
    """
    Instantiate and serve a vLLM prefill worker.

    Builds the engine and a PrefillWorkerHandler, wires up KV event publishers
    and metrics, and registers the sleep/wake_up engine routes. Non-leader
    data-parallel nodes stop there. The leader also registers the model as
    ModelType.Prefill and serves the generate and clear_kv_blocks endpoints
    until shutdown.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")

    (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
    ) = setup_vllm_engine(config)

    handler = PrefillWorkerHandler(
        runtime,
        component,
        engine_client,
        default_sampling_params,
        getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
    consolidator_enabled = False
    consolidator_port = None

    if (
        hasattr(vllm_config, "consolidator_endpoints")
        and vllm_config.consolidator_endpoints
    ):
        # Extract connect endpoint (third element) for clients to subscribe
        # consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint)
        consolidator_output_endpoint = vllm_config.consolidator_endpoints[2]
        consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
        consolidator_enabled = True

    # Set up KV event publishers for prefix caching if enabled (one per dp_rank)
    # If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
    kv_publishers = setup_kv_event_publisher(
        config,
        component,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    setup_metrics_collection(config, generate_endpoint, logger)

    # Register sleep/wake_up engine routes
    runtime.register_engine_route("sleep", handler.sleep)
    runtime.register_engine_route("wake_up", handler.wake_up)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # Handle non-leader nodes - don't serve endpoints
    if config.engine_args.data_parallel_rank:
        await _handle_non_leader_node(config.engine_args.data_parallel_rank)
        return

    # Register prefill model with ModelType.Prefill
    model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
    await register_vllm_model(
        model_input,
        ModelType.Prefill,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
        migration_limit=0,  # Prefill doesn't support migration
    )

    health_check_payload = VllmPrefillHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    ).to_dict()

    # In practice config.served_model_name is always set, but mypy needs the "or" here.
    # Every endpoint uses the same label (matching init()), so the metrics line up.
    model_label = ("model", config.served_model_name or config.model)

    try:
        logger.debug("Starting serve_endpoint for prefill worker")
        await asyncio.gather(
            # for prefill, we want to shutdown the engine after all prefill requests are finished because
            #     (temp reason): we don't support re-routing prefill requests
            #     (long-term reason): prefill engine should pull from a global queue so there is
            #                         only a few in-flight requests that can be quickly finished
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=[model_label],
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=[model_label],
            ),
        )
        logger.debug("serve_endpoint completed for prefill worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up prefill worker")
        handler.cleanup()


async def init(runtime: DistributedRuntime, config: Config):
    """
    Instantiate and serve the default (decode) vLLM worker.

    Sets up a stat-logger factory and the engine, then a DecodeWorkerHandler.
    It wires up KV event publishers and metrics and registers the sleep/wake_up
    engine routes. Non-leader data-parallel nodes stop there. The leader also
    registers the model for the configured endpoint types and serves generate,
    clear_kv_blocks and the LoRA management endpoints until shutdown.
    """

    def model_labels():
        # Build a fresh label list on every call. served_model_name is normally
        # already set by worker(); the fallback to config.model keeps mypy happy.
        return [("model", config.served_model_name or config.model)]

    worker_component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = worker_component.endpoint(config.endpoint)
    clear_endpoint = worker_component.endpoint("clear_kv_blocks")
    load_lora_endpoint = worker_component.endpoint("load_lora")
    unload_lora_endpoint = worker_component.endpoint("unload_lora")
    list_loras_endpoint = worker_component.endpoint("list_loras")

    stat_logger_factory = StatLoggerFactory(
        worker_component,
        config.engine_args.data_parallel_rank or 0,
        metrics_labels=model_labels(),
    )
    engine_client, vllm_config, default_sampling_params, prometheus_temp_dir = (
        setup_vllm_engine(config, stat_logger_factory)
    )

    # TODO Hack to get data, move this to registering in TBD
    stat_logger_factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks)
    stat_logger_factory.init_publish()

    max_model_len = getattr(
        getattr(vllm_config, "model_config", None), "max_model_len", None
    )
    handler = DecodeWorkerHandler(
        runtime,
        worker_component,
        engine_client,
        default_sampling_params,
        max_model_len,
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # setup_vllm_engine stores the KV event consolidator endpoints (or None) on
    # vllm_config as (vllm_endpoint, bind_endpoint, connect_endpoint); clients
    # subscribe on the connect endpoint, whose port is parsed here.
    endpoints = getattr(vllm_config, "consolidator_endpoints", None)
    consolidator_enabled = bool(endpoints)
    consolidator_port = int(endpoints[2].split(":")[-1]) if endpoints else None

    # KV event publishers for prefix caching; they read from the consolidator
    # output when one is configured.
    publishers = setup_kv_event_publisher(
        config,
        worker_component,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if publishers:
        handler.kv_publishers = publishers

    setup_metrics_collection(config, generate_endpoint, logger)

    # Expose sleep/wake_up on the engine route table
    for route_name, route_fn in (
        ("sleep", handler.sleep),
        ("wake_up", handler.wake_up),
    ):
        runtime.register_engine_route(route_name, route_fn)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # Non-leader data-parallel nodes run vLLM workers but serve no endpoints
    dp_rank = config.engine_args.data_parallel_rank
    if dp_rank:
        await _handle_non_leader_node(dp_rank)
        return

    # Endpoint types come from the --dyn-endpoint-types flag
    model_type = parse_endpoint_types(config.dyn_endpoint_types)
    logger.info(f"Registering model with endpoint types: {config.dyn_endpoint_types}")

    model_input = (
        ModelInput.Tokens if not config.use_vllm_tokenizer else ModelInput.Text
    )

    # A custom chat template is pointless unless the chat endpoint is enabled
    if config.custom_jinja_template and "chat" not in config.dyn_endpoint_types:
        logger.warning(
            "Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
            "The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
        )

    await register_vllm_model(
        model_input,
        model_type,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
        migration_limit=config.migration_limit,
    )

    health_check_payload = VllmHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    ).to_dict()

    try:
        logger.debug("Starting serve_endpoint for decode worker")
        # Decode in-flight requests are meant to be moved to other decode
        # engines (waiting for long outputs can take a while), so draining on
        # shutdown is only enabled when migration is turned off.
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=config.migration_limit <= 0,
                metrics_labels=model_labels(),
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks, metrics_labels=model_labels()
            ),
            load_lora_endpoint.serve_endpoint(
                handler.load_lora, metrics_labels=model_labels()
            ),
            unload_lora_endpoint.serve_endpoint(
                handler.unload_lora, metrics_labels=model_labels()
            ),
            list_loras_endpoint.serve_endpoint(
                handler.list_loras, metrics_labels=model_labels()
            ),
        )
        logger.debug("serve_endpoint completed for decode worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up decode worker")
        # Stop the handler's background work
        handler.cleanup()


def get_engine_cache_info(engine: AsyncLLM):
    """Read KV-cache and scheduler sizing values from an ``AsyncLLM`` engine.

    The values are read directly from ``engine.vllm_config``; no call is made
    to the engine workers.

    Args:
        engine: The initialized vLLM ``AsyncLLM`` engine.

    Returns:
        A dict with ``num_gpu_blocks`` (from ``cache_config``) followed by
        ``max_num_seqs`` and ``max_num_batched_tokens`` (from
        ``scheduler_config``).

    Raises:
        Exception: Any error raised while reading the config attributes is
            logged and re-raised unchanged.
    """
    try:
        vllm_config = engine.vllm_config
        cache_values = {
            "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
        }
        scheduler_values = {
            "max_num_seqs": vllm_config.scheduler_config.max_num_seqs,
            "max_num_batched_tokens": vllm_config.scheduler_config.max_num_batched_tokens,
        }
    except Exception as e:
        logger.error(f"Failed to get configuration values from vLLM config: {e}")
        raise

    # Use the module logger like the rest of this file, not the root logger.
    logger.info(f"Cache config values: {cache_values}")
    logger.info(f"Scheduler config values: {scheduler_values}")
    return {**cache_values, **scheduler_values}


async def init_multimodal_processor(runtime: DistributedRuntime, config: Config):
    """Initialize the multimodal processor component.

    The processor is registered as the model entrypoint (``ModelInput.Tokens``,
    ``ModelType.Chat``). Its ``PreprocessedHandler`` is given clients for the
    ``encoder`` component and for the ``backend`` (PD worker) component.

    Args:
        runtime: Dynamo distributed runtime used to resolve components.
        config: Parsed worker configuration.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)

    # Get encode worker client
    encode_worker_client = (
        await runtime.namespace(config.namespace)
        .component("encoder")
        .endpoint("generate")
        .client()
    )

    # Get PD worker client
    pd_worker_client = (
        await runtime.namespace(config.namespace)
        .component("backend")
        .endpoint("generate")
        .client()
    )

    handler = PreprocessedHandler(
        config.engine_args,
        encode_worker_client,
        pd_worker_client,
    )

    logger.info("Waiting for Encoder Worker Instances ...")
    await encode_worker_client.wait_for_instances()
    # The handler forwards requests to "backend" as well, so also wait for a
    # PD worker before advertising the model (as init_ec_processor does);
    # otherwise requests could reach the processor before any PD worker exists.
    logger.info("Waiting for PD Worker Instances ...")
    await pd_worker_client.wait_for_instances()

    # Register the endpoint as entrypoint to a model
    await register_llm(
        ModelInput.Tokens,
        ModelType.Chat,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info("Starting to serve the processor endpoint...")

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate, metrics_labels=[("model", config.model)]
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()


async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Config):
    """Initialize the multimodal encode worker component.

    Builds an ``EncodeWorkerHandler`` with a client for the ``backend`` (PD
    worker) component, runs its async initialization, waits until a PD worker
    is available, and then serves the ``generate`` endpoint until it stops.

    Args:
        runtime: Dynamo distributed runtime used to resolve components.
        config: Parsed worker configuration.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)

    # Get PD worker client
    # In multimodal mode, the PD worker always registers as "backend"
    # (even in disaggregated mode with prefill/decode split, we still connect to "backend")
    pd_worker_client = (
        await runtime.namespace(config.namespace)
        .component("backend")
        .endpoint("generate")
        .client()
    )

    handler = EncodeWorkerHandler(
        config.engine_args,
        pd_worker_client,
    )
    await handler.async_init(runtime)
    logger.info("Waiting for PD Worker Instances ...")
    await pd_worker_client.wait_for_instances()
    logger.info("Starting to serve the encode worker endpoint...")

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate, metrics_labels=[("model", config.model)]
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve encode worker endpoint: {e}")
        raise
    finally:
        handler.cleanup()


async def init_vllm_native_encoder(runtime: DistributedRuntime, config: Config):
    """
    Initialize vLLM-native encoder worker component (ECConnector mode).

    The vLLM engine is configured as the ``ec_producer`` side of the
    ECConnector: it creates the embeddings and stores them through the
    configured connector backend / storage path. The handler needs no
    separate ``async_init`` in this mode.

    Args:
        runtime: Dynamo distributed runtime used to resolve components.
        config: Parsed worker configuration; ``config.engine_args`` is mutated
            to carry the ECTransferConfig.
    """
    component = runtime.namespace(config.namespace).component(config.component)
    generate_endpoint = component.endpoint(config.endpoint)

    # NOTE(review): instance_id is fixed at 0, so every encoder replica gets
    # the same engine_id. Confirm this is intended for multi-replica setups.
    instance_id = 0
    engine_id = f"{config.namespace}.{config.component}.encoder.{instance_id}"

    # Producer role: this worker creates embeddings and stores them in the
    # shared storage for consumers to load.
    ec_transfer_config = create_ec_transfer_config(
        engine_id=engine_id,
        ec_role="ec_producer",
        ec_connector_backend=config.ec_connector_backend,
        ec_storage_path=config.ec_storage_path,
        ec_extra_config=config.ec_extra_config,
    )
    config.engine_args.ec_transfer_config = ec_transfer_config

    # Only the engine client and the Prometheus temp dir are used here.
    engine_client, _, _, prometheus_temp_dir = setup_vllm_engine(config)

    handler = VLLMEncodeWorkerHandler(
        runtime,
        component,
        engine_client,
        config,
    )
    # Presumably lets handler.cleanup() remove the temp dir. Verify in the handler.
    handler.add_temp_dir(prometheus_temp_dir)

    logger.info("Starting to serve vLLM-native encoder endpoint...")

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate, metrics_labels=[("model", config.model)]
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve vLLM-native encoder endpoint: {e}")
        raise
    finally:
        handler.cleanup()


async def init_ec_processor(runtime: DistributedRuntime, config: Config):
    """
    Initialize ECConnector processor component.

    Simple processor that routes multimodal requests using ECConnector pattern:
    1. Preprocess request (same as regular processor)
    2. Send multimodal items to encoder workers (stores to shared storage)
    3. Forward preprocessed request to PD worker (loads from shared storage)
    4. Stream response back to client

    The model is registered only after both an ``encoder`` and a ``backend``
    (PD worker) instance are available.

    Args:
        runtime: Dynamo distributed runtime used to resolve components.
        config: Parsed worker configuration.
    """
    component = runtime.namespace(config.namespace).component(config.component)
    generate_endpoint = component.endpoint(config.endpoint)

    # Get encoder worker client
    encoder_client = (
        await runtime.namespace(config.namespace)
        .component("encoder")
        .endpoint("generate")
        .client()
    )

    # Get PD worker client
    pd_client = (
        await runtime.namespace(config.namespace)
        .component("backend")
        .endpoint("generate")
        .client()
    )

    # Get prompt template from args (must be passed via environment or command line)
    mm_prompt_template = config.mm_prompt_template

    # Create EC processor handler (with preprocessing like regular processor)
    handler = ECProcessorHandler(
        config.engine_args,
        encoder_worker_client=encoder_client,
        pd_worker_client=pd_client,
        prompt_template=mm_prompt_template,
    )

    logger.info("Waiting for encoder and PD worker instances...")
    await encoder_client.wait_for_instances()
    await pd_client.wait_for_instances()

    # Register the endpoint as entrypoint to a model (same as preprocessed_handler)
    await register_llm(
        ModelInput.Tokens,  # Use Rust tokenization for better performance and multi-image support
        ModelType.Chat,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info("Starting to serve EC processor endpoint...")

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate, metrics_labels=[("model", config.model)]
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve EC processor endpoint: {e}")
        raise
    finally:
        handler.cleanup()


async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
    """
    Initialize multimodal worker component.

    Supports two modes:
    1. --multimodal-worker: Receives embeddings from separate encoder
    2. --multimodal-encode-prefill-worker: Handles inline encoding (e.g., Llama 4)

    Both can operate in aggregated (P+D) or disaggregated (P→D) mode:
    - When ``config.multimodal_decode_worker`` is set, the worker uses
      ``MultimodalDecodeWorkerHandler``.
    - Otherwise it uses ``MultimodalPDWorkerHandler``. If
      ``config.is_prefill_worker`` is set, it first connects to (and waits
      for) the ``decoder`` component and passes that client to the handler.

    When --ec-consumer-mode is enabled, configures as ECConnector consumer
    to load encoder embeddings from shared storage.

    Serves the ``generate`` and ``clear_kv_blocks`` endpoints until they stop.

    Args:
        runtime: Dynamo distributed runtime used to resolve components.
        config: Parsed worker configuration; ``config.engine_args`` may be
            mutated to carry the ECTransferConfig.
    """
    component = runtime.namespace(config.namespace).component(config.component)

    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")

    # Configure ECConnector consumer mode if enabled
    if config.ec_consumer_mode:
        logger.info("Configuring as ECConnector consumer for encoder embeddings")
        instance_id = 0
        engine_id = f"{config.namespace}.{config.component}.backend.{instance_id}"

        # The PD worker only loads the embeddings from the shared storage, so it is a consumer
        ec_transfer_config = create_ec_transfer_config(
            engine_id=engine_id,
            ec_role="ec_consumer",
            ec_connector_backend=config.ec_connector_backend,
            ec_storage_path=config.ec_storage_path,
            ec_extra_config=config.ec_extra_config,
        )

        # Set ECTransferConfig on engine args
        config.engine_args.ec_transfer_config = ec_transfer_config
        logger.info(f"Configured as ECConnector consumer with engine_id={engine_id}")

    # default_sampling_params is not used by this worker.
    (
        engine_client,
        vllm_config,
        _,
        prometheus_temp_dir,
    ) = setup_vllm_engine(config)

    # Set up decode worker client for disaggregated mode
    decode_worker_client = None
    if config.is_prefill_worker:
        # Prefill worker needs to connect to decode worker
        decode_worker_client = (
            await runtime.namespace(config.namespace)
            .component("decoder")
            .endpoint("generate")
            .client()
        )
        await decode_worker_client.wait_for_instances()
        logger.info("Connected to decode worker for disaggregated mode")

    # Choose handler based on worker type
    if config.multimodal_decode_worker:
        handler = MultimodalDecodeWorkerHandler(
            runtime, component, engine_client, config
        )
    else:
        handler = MultimodalPDWorkerHandler(
            runtime, component, engine_client, config, decode_worker_client
        )
    handler.add_temp_dir(prometheus_temp_dir)

    await handler.async_init(runtime)

    # Set up KV event publisher for prefix caching if enabled
    kv_publisher = setup_kv_event_publisher(
        config, component, generate_endpoint, vllm_config
    )
    if kv_publisher:
        handler.kv_publisher = kv_publisher

    metrics_labels = [("model", config.model)]
    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                metrics_labels=metrics_labels,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=metrics_labels,
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()


def main():
    """Synchronous entry point: run the async ``worker()`` coroutine on uvloop."""
    uvloop.run(worker())


if __name__ == "__main__":
    main()