main.py 15.3 KB
Newer Older
Alec's avatar
Alec committed
1
2
3
4
5
6
7
8
9
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import logging
import os
import signal

import uvloop
10
from prometheus_client import REGISTRY
Alec's avatar
Alec committed
11
12
13
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
14
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
Alec's avatar
Alec committed
15

16
from dynamo.common.config_dump import dump_config
17
from dynamo.common.utils.prometheus import register_engine_metrics_callback
Alec's avatar
Alec committed
18
from dynamo.llm import (
19
    ModelInput,
20
    ModelRuntimeConfig,
Alec's avatar
Alec committed
21
22
23
    ModelType,
    ZmqKvEventPublisher,
    ZmqKvEventPublisherConfig,
24
    fetch_llm,
Alec's avatar
Alec committed
25
26
27
28
29
    register_llm,
)
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging

30
from .args import ENABLE_LMCACHE, Config, configure_ports, overwrite_args, parse_args
Alec's avatar
Alec committed
31
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
32
from .health_check import VllmHealthCheckPayload, VllmPrefillHealthCheckPayload
Alec's avatar
Alec committed
33
34
from .publisher import StatLoggerFactory

Alec's avatar
Alec committed
35
36
37
38
# Configure Dynamo logging first, then create the logger used throughout this module.
configure_dynamo_logging()
logger = logging.getLogger(__name__)


39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def setup_lmcache_environment():
    """Fill in default LMCache environment variables for KV cache offloading.

    A variable that is already present in the process environment is left
    untouched, so externally supplied values take precedence over the
    defaults listed here.
    """
    defaults = (
        ("LMCACHE_CHUNK_SIZE", "256"),  # Token chunk size
        ("LMCACHE_LOCAL_CPU", "True"),  # Enable CPU memory backend
        ("LMCACHE_MAX_LOCAL_CPU_SIZE", "20"),  # CPU memory limit in GB
    )

    for key, value in defaults:
        if key in os.environ:
            # Already configured by the caller; keep their value.
            continue
        os.environ[key] = value
        logger.info(f"Set LMCache environment variable: {key}={value}")


Alec's avatar
Alec committed
55
56
async def graceful_shutdown(runtime):
    """
    Shut down the dynamo distributed runtime in response to a shutdown signal.

    Endpoints are invalidated immediately, so no new requests are accepted.
    Endpoints served with graceful_shutdown=True keep their serving function
    alive until all in-flight requests have finished; endpoints served with
    graceful_shutdown=False return from their serving function immediately.

    Args:
        runtime: Object exposing a synchronous ``shutdown()`` method.
    """
    logging.info("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()
    logging.info("DistributedRuntime shutdown complete")


@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """
    Entry point of the vLLM worker.

    Parses the command-line configuration, configures ports, installs
    SIGTERM/SIGINT handlers that trigger a graceful runtime shutdown, makes
    sure the model is available on disk, and then runs either the prefill or
    the decode initialisation until its endpoints stop serving.

    Args:
        runtime: The distributed runtime injected by ``dynamo_worker``.
    """
    config = parse_args()

    await configure_ports(runtime, config)
    overwrite_args(config)

    # Set up signal handler for graceful shutdown
    loop = asyncio.get_running_loop()

    # The event loop only keeps weak references to tasks. Hold the shutdown
    # task here until it completes so it cannot be garbage-collected before
    # the runtime has actually been shut down.
    shutdown_tasks = set()

    def signal_handler():
        task = asyncio.create_task(graceful_shutdown(runtime))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    logger.debug("Signal handlers set up for graceful shutdown")

    dump_config(config.dump_config_to, config)

    # Download the model if necessary.
    # register_llm would do this for us, but we want it on disk before we start vllm.
    # Ensure the original HF name (e.g. "Qwen/Qwen3-0.6B") is used as the served_model_name.
    if not config.served_model_name:
        config.served_model_name = config.engine_args.served_model_name = config.model
    if not os.path.exists(config.model):
        config.model = config.engine_args.model = await fetch_llm(config.model)

    if config.is_prefill_worker:
        await init_prefill(runtime, config)
        logger.debug("init_prefill completed")
    else:
        await init(runtime, config)
        logger.debug("init completed")

    logger.debug("Worker function completed, exiting...")
Alec's avatar
Alec committed
103
104


Yan Ru Pei's avatar
Yan Ru Pei committed
105
106
107
108
109
def setup_kv_event_publisher(
    config: Config,
    component,
    generate_endpoint,
    vllm_config,
):
    """
    Set up KV event publishers for prefix caching if enabled.
    Creates one publisher per dp_rank since each dp_rank publishes to a different port.

    Args:
        config: Worker configuration; reads ``engine_args.enable_prefix_caching``,
            ``engine_args.kv_events_config.endpoint`` and ``is_decode_worker``.
        component: Component handed to every ``ZmqKvEventPublisher``.
        generate_endpoint: Endpoint whose ``connection_id()`` is used as the worker id.
        vllm_config: vLLM configuration; reads ``parallel_config.data_parallel_size``
            (treated as 1 when the attribute is absent) and ``cache_config.block_size``.

    Returns:
        List of ZmqKvEventPublisher instances (one per dp_rank) if prefix caching is enabled, None otherwise.
    """
    if not config.engine_args.enable_prefix_caching:
        return None

    # Skip KV event publishing for decode workers
    if config.is_decode_worker:
        logger.info("Skipping KV event publisher setup for decode worker")
        return None

    # Get data_parallel_size to create publishers for all dp_ranks
    data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
    kv_publishers = []

    for dp_rank in range(data_parallel_size):
        # Each dp_rank publishes to a different port; the wildcard host of the
        # configured endpoint is replaced by the loopback address.
        zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
            config.engine_args.kv_events_config.endpoint,
            data_parallel_rank=dp_rank,
        ).replace("*", "127.0.0.1")

        zmq_config = ZmqKvEventPublisherConfig(
            worker_id=generate_endpoint.connection_id(),
            kv_block_size=vllm_config.cache_config.block_size,
            zmq_endpoint=zmq_endpoint,
        )
        kv_publishers.append(
            ZmqKvEventPublisher(component=component, config=zmq_config)
        )

        logger.info(
            f"Worker reading KV events for dp_rank={dp_rank} from {zmq_endpoint}"
        )

    # An empty list (data_parallel_size == 0) is reported as None, like the disabled case.
    return kv_publishers or None
Yan Ru Pei's avatar
Yan Ru Pei committed
150
151


Alec's avatar
Alec committed
152
def setup_vllm_engine(config, stat_logger=None):
    """
    Build the vLLM ``AsyncLLM`` engine client for this worker.

    Sets up multiprocess Prometheus, sets the vLLM environment variables this
    worker relies on, optionally applies the LMCache environment defaults, and
    creates the engine from ``config.engine_args``.

    Args:
        config: Worker configuration; ``engine_args`` and ``served_model_name`` are read.
        stat_logger: Optional stat logger passed to the engine. When falsy, the
            engine is created with an empty ``stat_loggers`` list.

    Returns:
        Tuple ``(engine_client, vllm_config, default_sampling_params)``.
    """
    setup_multiprocess_prometheus()
    logger.debug(
        f"Prometheus multiproc dir set to: {os.environ.get('PROMETHEUS_MULTIPROC_DIR')}"
    )

    os.environ["VLLM_NO_USAGE_STATS"] = "1"  # Avoid internal HTTP requests
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    engine_args = config.engine_args

    # KV transfer config is now handled by args.py based on ENABLE_LMCACHE env var
    if ENABLE_LMCACHE:
        setup_lmcache_environment()
        logger.info("LMCache enabled for VllmWorker")
    else:
        logger.debug("LMCache is disabled")

    # Load default sampling params from `generation_config.json`
    default_sampling_params = (
        engine_args.create_model_config().get_diff_sampling_param()
    )

    # Taken from build_async_engine_client_from_engine_args()
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    stat_loggers = [stat_logger] if stat_logger else []

    engine_client = AsyncLLM.from_vllm_config(
        vllm_config=vllm_config,
        usage_context=usage_context,
        stat_loggers=stat_loggers,
        disable_log_requests=engine_args.disable_log_requests,
        disable_log_stats=engine_args.disable_log_stats,
    )
    if ENABLE_LMCACHE:
        logger.info(
            f"VllmWorker for {config.served_model_name} has been initialized with LMCache"
        )
    else:
        logger.info(f"VllmWorker for {config.served_model_name} has been initialized")

    return engine_client, vllm_config, default_sampling_params


200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
async def register_vllm_model(
    model_input: ModelInput,
    model_type: ModelType,
    generate_endpoint,
    config: Config,
    engine_client: AsyncLLM,
    vllm_config,
    migration_limit: int,
):
    """
    Register a vLLM model together with its runtime configuration.

    The runtime configuration is populated from the engine's cache and
    scheduler settings and from the data-parallel size found on
    ``vllm_config``; tool-call and reasoning parsers are attached for every
    model type except ``ModelType.Prefill``.

    Args:
        model_input: Input type for the model (e.g., ModelInput.Tokens)
        model_type: Type of model (e.g., ModelType.Chat, ModelType.Prefill)
        generate_endpoint: Endpoint to register
        config: Configuration object
        engine_client: vLLM engine client
        vllm_config: vLLM configuration
        migration_limit: Migration limit for the model
    """
    runtime_config = ModelRuntimeConfig()

    # Pull cache/scheduler limits from the vLLM engine
    logging.info(
        f"Getting engine runtime configuration metadata from vLLM engine for {model_type}..."
    )
    engine_info = get_engine_cache_info(engine_client)
    runtime_config.total_kv_blocks = engine_info["num_gpu_blocks"]
    runtime_config.max_num_seqs = engine_info["max_num_seqs"]
    runtime_config.max_num_batched_tokens = engine_info["max_num_batched_tokens"]

    # Parsers only apply to non-prefill (decode) models
    is_prefill_model = model_type == ModelType.Prefill
    if not is_prefill_model:
        runtime_config.tool_call_parser = config.tool_call_parser
        runtime_config.reasoning_parser = config.reasoning_parser

    # data_parallel_size falls back to 1 when vllm_config does not carry it
    runtime_config.data_parallel_size = getattr(
        vllm_config.parallel_config, "data_parallel_size", 1
    )

    await register_llm(
        model_input,
        model_type,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
        migration_limit=migration_limit,
        runtime_config=runtime_config,
        custom_template_path=config.custom_jinja_template,
    )


Alec's avatar
Alec committed
254
255
256
257
async def init_prefill(runtime: DistributedRuntime, config: Config):
    """
    Instantiate the vLLM engine for a prefill worker and serve its endpoints.

    Creates the component service, builds the engine and a
    ``PrefillWorkerHandler``, attaches KV event publishers when any are
    created, registers the model as ``ModelType.Prefill`` (only when
    ``data_parallel_rank`` is 0 or None) and then serves the generate and
    ``clear_kv_blocks`` endpoints until both complete. ``handler.cleanup()``
    is always called on the way out.

    Args:
        runtime: Distributed runtime used to resolve the namespace/component.
        config: Worker configuration.

    Raises:
        Exception: Any error raised while serving the endpoints is logged and re-raised.
    """
    component = runtime.namespace(config.namespace).component(config.component)
    await component.create_service()

    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")

    engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config)

    handler = PrefillWorkerHandler(
        runtime, component, engine_client, default_sampling_params
    )

    # Set up KV event publishers for prefix caching if enabled (one per dp_rank)
    kv_publishers = setup_kv_event_publisher(
        config, component, generate_endpoint, vllm_config
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    # Register prefill model with ModelType.Prefill
    if not config.engine_args.data_parallel_rank:  # if rank is 0 or None then register
        await register_vllm_model(
            ModelInput.Tokens,
            ModelType.Prefill,
            generate_endpoint,
            config,
            engine_client,
            vllm_config,
            migration_limit=0,  # Prefill doesn't support migration
        )

    health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict()

    # In practice config.served_model_name is always set, but mypy needs the "or" here.
    # Both endpoints use the same label value so their metrics stay consistent.
    model_label = config.served_model_name or config.model

    try:
        logger.debug("Starting serve_endpoint for prefill worker")
        await asyncio.gather(
            # for prefill, we want to shutdown the engine after all prefill requests are finished because
            #     (temp reason): we don't support re-routing prefill requests
            #     (long-term reason): prefill engine should pull from a global queue so there is
            #                         only a few in-flight requests that can be quickly finished
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=[("model", model_label)],
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=[("model", model_label)],
            ),
        )
        logger.debug("serve_endpoint completed for prefill worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up prefill worker")
        handler.cleanup()


async def init(runtime: DistributedRuntime, config: Config):
    """
    Instantiate the vLLM engine for a decode worker and serve its endpoints.

    Creates the component service, builds the engine with a
    ``StatLoggerFactory``, attaches KV event publishers when any are created,
    registers the engine metrics callback unless ``disable_log_stats`` is set,
    registers the model for chat and completions (only when
    ``data_parallel_rank`` is 0 or None) and then serves the generate and
    ``clear_kv_blocks`` endpoints until both complete. ``handler.cleanup()``
    is always called on the way out.

    Args:
        runtime: Distributed runtime used to resolve the namespace/component.
        config: Worker configuration.

    Raises:
        Exception: Any error raised while serving the endpoints is logged and re-raised.
    """
    component = runtime.namespace(config.namespace).component(config.component)
    await component.create_service()

    generate_endpoint = component.endpoint(config.endpoint)
    clear_endpoint = component.endpoint("clear_kv_blocks")

    # In practice config.served_model_name is always set, but mypy needs the "or" here.
    model_label = config.served_model_name or config.model

    factory = StatLoggerFactory(
        component,
        config.engine_args.data_parallel_rank or 0,
        metrics_labels=[("model", model_label)],
    )
    engine_client, vllm_config, default_sampling_params = setup_vllm_engine(
        config, factory
    )

    # TODO Hack to get data, move this to registering in TBD
    factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks)
    factory.set_request_total_slots_all(vllm_config.scheduler_config.max_num_seqs)
    factory.init_publish()

    handler = DecodeWorkerHandler(
        runtime,
        component,
        engine_client,
        default_sampling_params,
    )

    # Set up KV event publishers for prefix caching if enabled (one per dp_rank)
    kv_publishers = setup_kv_event_publisher(
        config, component, generate_endpoint, vllm_config
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    # Only an explicit False enables the callback; None or any other value skips it.
    if config.engine_args.disable_log_stats is False:
        register_engine_metrics_callback(
            endpoint=generate_endpoint, registry=REGISTRY, metric_prefix_filter="vllm:"
        )

    if not config.engine_args.data_parallel_rank:  # if rank is 0 or None then register
        await register_vllm_model(
            ModelInput.Tokens,
            ModelType.Chat | ModelType.Completions,
            generate_endpoint,
            config,
            engine_client,
            vllm_config,
            migration_limit=config.migration_limit,
        )

    health_check_payload = VllmHealthCheckPayload(engine_client).to_dict()

    try:
        logger.debug("Starting serve_endpoint for decode worker")
        await asyncio.gather(
            # for decode, we want to transfer the in-flight requests to other decode engines,
            # because waiting them to finish can take a long time for long OSLs
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=config.migration_limit <= 0,
                metrics_labels=[("model", model_label)],
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=[("model", model_label)],
            ),
        )
        logger.debug("serve_endpoint completed for decode worker")
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        logger.debug("Cleaning up decode worker")
        # Cleanup background tasks
        handler.cleanup()


402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
def get_engine_cache_info(engine: AsyncLLM):
    """Collect cache and scheduler limits from an [`AsyncLLM`] engine.

    The values are read straight from ``engine.vllm_config`` (no collective
    RPC is issued).

    Returns:
        Dict with the keys ``num_gpu_blocks``, ``max_num_seqs`` and
        ``max_num_batched_tokens``.

    Raises:
        Exception: Any failure while reading the configuration is logged and re-raised.
    """
    try:
        engine_config = engine.vllm_config

        cache_config = engine_config.cache_config
        cache_values = {"num_gpu_blocks": cache_config.num_gpu_blocks}

        scheduler_config = engine_config.scheduler_config
        scheduler_values = {
            "max_num_seqs": scheduler_config.max_num_seqs,
            "max_num_batched_tokens": scheduler_config.max_num_batched_tokens,
        }

        logging.info(f"Cache config values: {cache_values}")
        logging.info(f"Scheduler config values: {scheduler_values}")

        # Flatten both groups into the single mapping callers expect.
        return {**cache_values, **scheduler_values}
    except Exception as e:
        logging.error(f"Failed to get configuration values from vLLM config: {e}")
        raise


Alec's avatar
Alec committed
428
429
430
431
def main():
    """Run the worker coroutine to completion on a uvloop event loop."""
    worker_coro = worker()
    uvloop.run(worker_coro)


Alec's avatar
Alec committed
432
# Start the worker when this module is executed as a script.
if __name__ == "__main__":
    main()