core.py 73.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
import queue
5
import signal
6
7
import threading
import time
8
from collections import defaultdict, deque
9
from collections.abc import Callable, Generator
10
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
11
from contextlib import ExitStack, contextmanager
12
from functools import partial
13
from inspect import isclass, signature
14
from logging import DEBUG
15
from typing import Any, TypeVar, cast
16

17
import msgspec
18
19
import zmq

20
21
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
22
from vllm.envs import enable_envs_cache
23
from vllm.logger import init_logger
24
from vllm.logging_utils.dump_input import dump_engine_exception
25
from vllm.lora.request import LoRARequest
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.tasks import POOLING_TASKS, SupportedTask
28
from vllm.tracing import instrument, maybe_init_worker_tracer
29
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
30
31
32
33
from vllm.utils.gc_utils import (
    freeze_gc_heap,
    maybe_attach_gc_debug_callback,
)
34
from vllm.utils.hashing import get_hash_fn_by_name
35
from vllm.utils.network_utils import make_zmq_socket
36
from vllm.utils.system_utils import decorate_logs, set_process_title
37
38
39
40
41
42
43
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
44
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
45
from vllm.v1.core.sched.output import SchedulerOutput
46
from vllm.v1.engine import (
47
    EngineCoreOutput,
48
49
50
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
51
    FinishReason,
52
    PauseMode,
53
54
55
56
57
58
59
60
61
62
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
63
from vllm.v1.executor import Executor
64
from vllm.v1.kv_cache_interface import KVCacheConfig
65
from vllm.v1.metrics.stats import SchedulerStats
66
from vllm.v1.outputs import ModelRunnerOutput
67
from vllm.v1.request import Request, RequestStatus
68
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
69
from vllm.v1.structured_output import StructuredOutputManager
70
from vllm.v1.utils import compute_iteration_details
71
72
73
74
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Handshake timeout, in minutes (per the name; the handshake code that reads
# it is outside this view).
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar("_R")  # Return type for collective_rpc


class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the KV-cache setup and the scheduler, and drives
    the schedule -> execute -> update cycle through ``step`` (or
    ``step_with_batch_queue`` when the executor reports more than one
    concurrent batch).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
        include_finished_set: bool = False,
    ):
        """Build the executor, KV caches and scheduler.

        Args:
            vllm_config: Engine configuration. Mutated in place during setup:
                ``cache_config.num_gpu_blocks``/``num_cpu_blocks`` are filled
                in, ``scheduler_config.enable_chunked_prefill`` may be turned
                off, and KV-cache auto-fit may reduce
                ``model_config.max_model_len``.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Forwarded to the scheduler.
            executor_fail_callback: If given, registered on the executor via
                ``register_failure_callback``.
            include_finished_set: Forwarded to the scheduler constructor.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        # Only log the full config when the local DP rank is 0 (or unset).
        if not vllm_config.parallel_config.data_parallel_rank_local:
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        # Sentinel until _initialize_kv_caches measures/syncs it; stays -1 for
        # models without any KV cache spec.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        Scheduler = vllm_config.scheduler_config.get_scheduler_cls()

        if len(kv_cache_config.kv_cache_groups) == 0:  # noqa: SIM102
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            if vllm_config.scheduler_config.enable_chunked_prefill:
                logger.warning("Disabling chunked prefill for model without KVCache")
                vllm_config.scheduler_config.enable_chunked_prefill = False

        # The scheduler-side block size is scaled by the decode and prefill
        # context-parallel sizes (presumably so one scheduler block covers the
        # tokens spread across CP ranks). The same size is used for hashing.
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
            * vllm_config.parallel_config.prefill_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=include_finished_set,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None
        if self.scheduler.connector is not None:  # type: ignore
            self.model_executor.init_kv_output_aggregator(self.scheduler.connector)  # type: ignore

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = mm_registry.engine_receiver_cache_from_config(
            vllm_config
        )

        # If a KV connector is initialized for scheduler, we want to collect
        # handshake metadata from all workers so the connector in the scheduler
        # will have the full context
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            # Collect and store KV connector xfer metadata from workers
            # (after KV cache registration)
            xfer_handshake_metadata = (
                self.model_executor.get_kv_connector_handshake_metadata()
            )

            if xfer_handshake_metadata:
                # xfer_handshake_metadata is list of dicts from workers
                # Each dict already has structure {tp_rank: metadata}
                # Merge all worker dicts into a single dict
                content: dict[int, Any] = {}
                for worker_dict in xfer_handshake_metadata:
                    if worker_dict is not None:
                        content.update(worker_dict)
                kv_connector.set_xfer_handshake_metadata(content)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput, Future[Any]]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.debug("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        self.is_ec_producer = (
            vllm_config.ec_transfer_config is not None
            and vllm_config.ec_transfer_config.is_ec_producer
        )
        self.is_pooling_model = vllm_config.model_config.runner_type == "pooling"

        # Block hashing is needed for prefix caching and for KV connectors.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )
        self.async_scheduling = vllm_config.scheduler_config.async_scheduling

        # Aborts that arrive while a step is running; drained by
        # _process_aborts_queue before the model output is applied.
        self.aborts_queue = queue.Queue[list[str]]()

        self._idle_state_callbacks: list[Callable] = []

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()
        # If enable, attach GC debugger after static variable freeze.
        maybe_attach_gc_debug_callback()
        # Enable environment variable cache (e.g. assume no more
        # environment variable overrides after this point)
        enable_envs_cache()

    @instrument(span_name="Prepare model")
    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Size the KV cache, configure the workers and warm up the model.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``;
            ``num_cpu_blocks`` is always 0 here.
        """
        # perf_counter is monotonic, so the reported duration is not skewed
        # by wall-clock adjustments the way time.time() can be.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # dp_group is not set by EngineCore itself; presumably a
                # subclass sets it before this runs (asserted below).
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        # Track max_model_len before KV cache config to detect auto-fit changes
        max_model_len_before = vllm_config.model_config.max_model_len

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )

        # If auto-fit reduced max_model_len, sync the new value to workers.
        # This is needed because workers were spawned before memory profiling
        # and have the original (larger) max_model_len cached.
        max_model_len_after = vllm_config.model_config.max_model_len
        if max_model_len_after != max_model_len_before:
            self.collective_rpc("update_max_model_len", args=(max_model_len_after,))

        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info_once(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
            scope="local",
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model executor reports as supported."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case (unused by this implementation).

        Raises:
            TypeError: if ``request.request_id`` is not a ``str``.
            ValueError: if a pooling request asks for a task the model does
                not support.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        if request.kv_transfer_params is not None and (
            not self.scheduler.get_kv_connector()
        ):
            # NOTE(review): kv_transfer_params are only warned about here, not
            # cleared; presumably the scheduler ignores them when no connector
            # exists -- confirm.
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Context manager that dumps engine state if the wrapped block fails.

        On an ``Exception`` inside the block, ``dump_engine_exception`` is
        called with ``scheduler_output`` and the scheduler stats, then the
        exception is re-raised unchanged.
        """
        try:
            yield
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare raise re-raises the active exception with its original
            # traceback (no extra frame from re-raising a bound name).
            raise

    @contextmanager
    def log_iteration_details(self, scheduler_output: SchedulerOutput):
        """Optionally log per-iteration batch composition and latency.

        No-op unless ``observability_config.enable_logging_iteration_details``
        is set. Nothing is logged when the wrapped block raises.
        """
        if not self.vllm_config.observability_config.enable_logging_iteration_details:
            yield
            return
        # Lazily-initialized per-engine iteration counter.
        self._iteration_index = getattr(self, "_iteration_index", 0)
        iteration_details = compute_iteration_details(scheduler_output)
        before = time.monotonic()
        yield
        elapsed_ms = (time.monotonic() - before) * 1000
        # Lazy %-style args: the log text is unchanged, but formatting is left
        # to the logging framework instead of being built eagerly.
        logger.info(
            "Iteration(%s): %s context requests, %s context tokens, "
            "%s generation requests, %s generation tokens, "
            "iteration elapsed time: %.2f ms",
            self._iteration_index,
            iteration_details.num_ctx_requests,
            iteration_details.num_ctx_tokens,
            iteration_details.num_generation_requests,
            iteration_details.num_generation_tokens,
            elapsed_ms,
        )
        self._iteration_index += 1

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output, non_block=True)
        # Computed after the non-blocking execute_model dispatch so it can
        # overlap with model execution.
        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
            model_output = future.result()
            if model_output is None:
                # execute_model returned no output; sampling is requested
                # separately with the grammar bitmask.
                model_output = self.model_executor.sample_tokens(grammar_output)

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

    def post_step(self, model_executed: bool) -> None:
        """Post-step hook: feed draft token ids back to the scheduler."""
        # When using async scheduling we can't get draft token ids in advance,
        # so we update draft token ids in the worker process and don't
        # need to update draft token ids here.
        if not self.async_scheduling and self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        deferred_scheduler_output = None
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            exec_future = self.model_executor.execute_model(
                scheduler_output, non_block=True
            )
            if not self.is_ec_producer:
                model_executed = scheduler_output.total_num_scheduled_tokens > 0

            if self.is_pooling_model or not model_executed:
                # No sampling required (no requests scheduled).
                future = cast(Future[ModelRunnerOutput], exec_future)
            else:
                if not scheduler_output.pending_structured_output_tokens:
                    # We aren't waiting for any tokens, get any grammar output
                    # and sample immediately.
                    grammar_output = self.scheduler.get_grammar_bitmask(
                        scheduler_output
                    )
                    future = self.model_executor.sample_tokens(
                        grammar_output, non_block=True
                    )
                else:
                    # We need to defer sampling until we have processed the model output
                    # from the prior step.
                    deferred_scheduler_output = scheduler_output

            if deferred_scheduler_output is None:
                # Add this step's future to the queue.
                batch_queue.appendleft((future, scheduler_output, exec_future))
                if (
                    model_executed
                    and len(batch_queue) < self.batch_queue_size
                    and not batch_queue[-1][0].done()
                ):
                    # Don't block on next worker response unless the queue is full
                    # or there are no more requests to schedule.
                    return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output, exec_model_fut = batch_queue.pop()
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
            model_output = future.result()
            if model_output is None:
                # None from sample_tokens() implies that the original execute_model()
                # call failed - raise that exception.
                exec_model_fut.result()
                raise RuntimeError("unexpected error")

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        # NOTE(nick): We can either handle the deferred tasks here or save
        # in a field and do it immediately once step_with_batch_queue is
        # re-called. The latter slightly favors TTFT over TPOT/throughput.
        if deferred_scheduler_output is not None:
            # If we are doing speculative decoding with structured output,
            # we need to get the draft token ids from the prior step before
            # we can compute the grammar bitmask for the deferred request.
            if self.use_spec_decode:
                draft_token_ids = self.model_executor.take_draft_token_ids()
                assert draft_token_ids is not None
                # Update the draft token ids in the scheduler output to
                # filter out the invalid spec tokens, which will be padded
                # with -1 and skipped by the grammar bitmask computation.
                self.scheduler.update_draft_token_ids_in_output(
                    draft_token_ids, deferred_scheduler_output
                )
            # We now have the tokens needed to compute the bitmask for the
            # deferred request. Get the bitmask and call sample tokens.
            grammar_output = self.scheduler.get_grammar_bitmask(
                deferred_scheduler_output
            )
            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
            # exec_future is the execute_model future of the deferred batch,
            # bound in the scheduling branch above.
            batch_queue.appendleft((future, deferred_scheduler_output, exec_future))

        return engine_core_outputs, model_executed

    def _process_aborts_queue(self):
        """Drain ``aborts_queue`` and abort all collected ids in one batch."""
        if not self.aborts_queue.empty():
            request_ids = []
            while not self.aborts_queue.empty():
                ids = self.aborts_queue.get_nowait()
                # Should be a list here, but also handle string just in case.
                request_ids.extend((ids,) if isinstance(ids, str) else ids)
            # More efficient to abort all as a single batch.
            self.abort_requests(request_ids)

    def shutdown(self):
        """Release the structured-output backend, executor and scheduler."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True, profile_prefix: str | None = None):
        """Start or stop profiling on the executor."""
        self.model_executor.profile(is_start, profile_prefix)

    def reset_mm_cache(self):
        """Clear the multi-modal caches held by the engine and the executor."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the scheduler's prefix cache; returns the scheduler's result."""
        return self.scheduler.reset_prefix_cache(
            reset_running_requests, reset_connector
        )

    def reset_encoder_cache(self) -> None:
        """Reset the encoder cache to invalidate all cached encoder outputs.

        This should be called when model weights are updated to ensure
        stale vision embeddings computed with old weights are not reused.
        Clears both the scheduler's cache manager and the GPU model runner's cache.
        """
        # NOTE: We don't attempt to re-sync the internal caches (P0 sender,
        # P1 receiver); resetting while requests are in flight only warns.
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the encoder cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # Reset the scheduler's encoder cache manager (logical state)
        self.scheduler.reset_encoder_cache()
        # Reset the GPU model runner's encoder cache (physical storage)
        self.model_executor.reset_encoder_cache()

    def _reset_caches(self, reset_running_requests=True) -> None:
        """Reset the prefix, multi-modal and encoder caches."""
        self.reset_prefix_cache(reset_running_requests=reset_running_requests)
        self.reset_mm_cache()
        self.reset_encoder_cache()

    def pause_scheduler(
        self, mode: PauseMode = "abort", clear_cache: bool = True
    ) -> Future | None:
        """Pause generation; behavior depends on mode.

        All pause modes queue new adds -- "abort" and "keep" skip step();
        "wait" allows step() so in-flight requests can drain.

        - ``abort``: Set PAUSED_NEW, abort all requests, wait for abort
          outputs to be sent (when running with output_queue), optionally
          clear caches, then complete the returned Future.
        - ``wait``: Set PAUSED_NEW (queue adds, keep stepping); when drained,
          optionally clear caches, then complete the returned Future.
        - ``keep``: Set PAUSED_ALL; return a Future that completes when the
          output queue is empty.

        Note: this in-process implementation completes synchronously and
        always returns ``None``, and it rejects ``wait`` with ``ValueError``.
        The Future-returning behavior above presumably belongs to overriding
        implementations that run with an output queue.

        Raises:
            ValueError: for an unknown mode, or for ``wait`` here.
        """
        if mode not in ("keep", "abort", "wait"):
            raise ValueError(f"Invalid pause mode: {mode}")
        if mode == "wait":
            raise ValueError("'wait' mode can't be used in inproc-engine mode")

        if mode == "abort":
            self.scheduler.finish_requests(None, RequestStatus.FINISHED_ABORTED)

        pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
        self.scheduler.set_pause_state(pause_state)
        if clear_cache:
            self._reset_caches()

        return None

    def resume_scheduler(self) -> None:
        """Resume the scheduler and flush any requests queued while paused."""
        self.scheduler.set_pause_state(PauseState.UNPAUSED)

    def is_scheduler_paused(self) -> bool:
        """Return whether the scheduler is in any pause state."""
        return self.scheduler.pause_state != PauseState.UNPAUSED

    def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None | Future:
        """Put the engine to sleep at the specified level.

        Args:
            level: Sleep level.
                - Level 0: Pause scheduling only. Requests are still accepted
                           but not processed. No GPU memory changes.
                - Level 1: Offload model weights to CPU, discard KV cache.
                - Level 2: Discard all GPU memory.
            mode: Pause mode - how to deal with any existing requests, see
                documentation of pause_scheduler method.

        Returns:
            ``None`` when everything completed synchronously; otherwise a
            Future that resolves once the pause and the executor sleep finish.
        """
        # Pause scheduler before sleeping; caches are only cleared when
        # GPU memory is going to be released (level >= 1).
        clear_cache = level >= 1
        pause_future = self.pause_scheduler(mode=mode, clear_cache=clear_cache)
        if level < 1:
            return pause_future

        # Level 1+: Delegate to executor for GPU memory management
        model_executor = self.model_executor
        if pause_future is None:
            model_executor.sleep(level)
            return None

        future = Future[Any]()

        def pause_complete(f: Future):
            # Chain: only sleep the executor once the pause has completed.
            try:
                f.result()  # propagate any exception
                future.set_result(model_executor.sleep(level))
            except Exception as e:
                future.set_exception(e)

        logger.info("Waiting for in-flight requests to complete before sleeping...")
        pause_future.add_done_callback(pause_complete)
        return future

    def wake_up(self, tags: list[str] | None = None):
        """Wake up the engine from sleep.

        Args:
            tags: Tags to wake up. Use ["scheduling"] for level 0 wake up.
        """
        if tags is not None and "scheduling" in tags:
            # Remove "scheduling" from tags if there are other tags to process.
            tags = [t for t in tags if t != "scheduling"]

        # None means "wake everything"; an empty list (only "scheduling" was
        # requested) skips the executor.
        if tags is None or tags:
            self.model_executor.wake_up(tags)

        # Resume scheduling (applies to all levels)
        self.resume_scheduler()

    def is_sleeping(self) -> bool:
        """Check if engine is sleeping at any level."""
        return self.is_scheduler_paused() or self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run a dummy batch on the executor."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter via the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter via the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the ids of LoRA adapters known to the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter via the executor."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Forward a sharded-state save request to the executor."""
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Invoke ``method`` on all workers through the executor."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The constructed ``Request`` and ``request.current_wave``.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    # Sentinel pushed onto output_queue to tell the output thread (and, via
    # the output sockets, the client) that this engine core is dead.
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"

    @instrument(span_name="EngineCoreProc init")
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        *,
        engine_index: int = 0,
    ):
        """Handshake with the front-end(s), build the engine core, and start
        the input/output socket IO threads.

        The READY message of the handshake is only sent once this
        constructor's ``with`` block completes, i.e. after the engine is
        initialized and (if present) the DP Coordinator has reported READY.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()
        # Executor failures are routed through the input queue so that the
        # busy loop raises them on its own thread.
        executor_fail_callback = lambda: self.input_queue.put_nowait(
            (EngineCoreRequestType.EXECUTOR_FAILED, b"")
        )

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        ) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            internal_dp_balancing = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )
            # Only publish request queue stats to coordinator for "internal"
            # and "hybrid" LB modes.
            self.publish_dp_lb_stats = internal_dp_balancing

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config,
                executor_class,
                log_stats,
                executor_fail_callback,
                internal_dp_balancing,
            )

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    identity,
                    ready_event,
                ),
                daemon=True,
            )
            input_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(
                    addresses.outputs,
                    addresses.coordinator_output,
                    self.engine_index,
                ),
                daemon=True,
            )
            self.output_thread.start()

            # Don't complete handshake until DP coordinator ready message is
            # received.
            while not ready_event.wait(timeout=10):
                if not input_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.

        After the caller's ``with`` body completes normally,
        ``vllm_config.__post_init__()`` is re-run because the handshake may
        have updated the parallel config.
        """
        # NOTE(review): this context is not explicitly terminated in this
        # method; confirm whether that is intended.
        input_ctx = zmq.Context()
        is_local = local_client and client_handshake_address is None
        headless = not local_client
        handshake = self._perform_handshake(
            input_ctx,
            handshake_address,
            identity,
            is_local,
            headless,
            vllm_config,
            vllm_config.parallel_config,
        )
        if client_handshake_address is None:
            with handshake as addresses:
                yield addresses
        else:
            assert local_client
            local_handshake = self._perform_handshake(
                input_ctx, client_handshake_address, identity, True, False, vllm_config
            )
            # Input/output socket addresses come from the colocated
            # front-end; everything else from the primary handshake.
            with handshake as addresses, local_handshake as client_addresses:
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
                yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run a single HELLO / init / READY exchange with one front-end.

        Yields the addresses from the front-end's init message. If the
        caller's ``with`` body completes without raising, a READY message
        (including ``num_gpu_blocks`` and, for DP>1, the parallel config
        hash) is sent on the same DEALER socket.
        """
        with make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        ) as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(
                handshake_socket, local_client, headless, parallel_config_to_update
            )
            yield addresses

            # Send ready message.
            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
            # We pass back the coordinator stats update address here for the
            # external LB case for our colocated front-end to use (coordinator
            # only runs with rank 0).
            dp_stats_address = self.frontend_stats_publish_address

            # Include config hash for DP configuration validation
            ready_msg = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": num_gpu_blocks,
                "dp_stats_address": dp_stats_address,
            }
            if vllm_config.parallel_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = (
                    vllm_config.parallel_config.compute_hash()
                )

            handshake_socket.send(msgspec.msgpack.encode(ready_msg))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Send HELLO, wait for the front-end's init message, and return the
        socket addresses it carries.

        If ``parallel_config`` is given, each ``(key, value)`` pair in the
        init message's ``parallel_config`` mapping is applied to it with
        ``setattr``.

        Raises:
            RuntimeError: if no init message arrives within
                ``HANDSHAKE_TIMEOUT_MINS`` minutes.
        """
        # Send registration message.
        handshake_socket.send(
            msgspec.msgpack.encode(
                {
                    "status": "HELLO",
                    "local": local_client,
                    "headless": headless,
                }
            )
        )

        # Receive initialization message.
        logger.debug("Waiting for init message from front-end.")
        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                "minutes"
            )
        init_bytes = handshake_socket.recv()
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            init_bytes, type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for key, value in init_message.parallel_config.items():
                setattr(parallel_config, key, value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process."""

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: EngineCoreProc | None = None
        try:
            vllm_config: VllmConfig = kwargs["vllm_config"]
            parallel_config: ParallelConfig = vllm_config.parallel_config
            data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
            if data_parallel:
                parallel_config.data_parallel_rank_local = local_dp_rank
                maybe_init_worker_tracer(
                    instrumenting_module_name="vllm.engine_core",
                    process_kind="engine_core",
                    process_name=f"EngineCore_DP{dp_rank}",
                )
                set_process_title("EngineCore", f"DP{dp_rank}")
            else:
                maybe_init_worker_tracer(
                    instrumenting_module_name="vllm.engine_core",
                    process_kind="engine_core",
                    process_name="EngineCore",
                )
                set_process_title("EngineCore")
            decorate_logs()

            if data_parallel and vllm_config.kv_transfer_config is not None:
                # modify the engine_id and append the local_dp_rank to it to ensure
                # that the kv_transfer_config is unique for each DP rank.
                vllm_config.kv_transfer_config.engine_id = (
                    f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
                )
                logger.debug(
                    "Setting kv_transfer_config.engine_id to %s",
                    vllm_config.kv_transfer_config.engine_id,
                )

            parallel_config.data_parallel_index = dp_rank
            if data_parallel and vllm_config.model_config.is_moe:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                # Non-MoE DP ranks are completely independent, so treat like DP=1.
                # Note that parallel_config.data_parallel_index will still reflect
                # the original DP rank.
                parallel_config.data_parallel_size = 1
                parallel_config.data_parallel_size_local = 1
                parallel_config.data_parallel_rank = 0
                engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)

            assert engine_core is not None
            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare raise keeps the original traceback intact.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Data-parallel setup hook; a no-op here, overridden by
        DPEngineCoreProc."""
        pass

    def has_work(self) -> bool:
        """Returns true if the engine should be stepped."""
        return (
            self.engines_running
            or self.scheduler.has_requests()
            or bool(self.batch_queue)
        )

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed."""

        waited = False
        while not self.has_work():
            # Notify callbacks waiting for engine to become idle.
            self._notify_idle_state_callbacks()
            if self.input_queue.empty():
                # Drain aborts queue; all aborts are also processed via input_queue.
                with self.aborts_queue.mutex:
                    self.aborts_queue.queue.clear()
                if logger.isEnabledFor(DEBUG):
                    logger.debug("EngineCore waiting for work.")
                    waited = True
            # Blocks until the input thread enqueues something.
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self) -> bool:
        """Called only when there are unfinished local requests."""

        # Step the engine core.
        outputs, model_executed = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        for output in outputs.items() if outputs else ():
            self.output_queue.put_nowait(output)
        # Post-step hook.
        self.post_step(model_executed)

        # If no model execution happened but there are waiting requests
        # (e.g., WAITING_FOR_REMOTE_KVS), yield the GIL briefly to allow
        # background threads (like NIXL handshake) to make progress.
        # Without this, the tight polling loop can starve background threads.
        if not model_executed and self.scheduler.has_unfinished_requests():
            time.sleep(0.001)

        return model_executed

    def _notify_idle_state_callbacks(self) -> None:
        """Pop and invoke every registered idle-state callback with ``self``."""
        while self._idle_state_callbacks:
            callback = self._idle_state_callbacks.pop()
            callback(self)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch request from client."""

        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            # Lazily look-up utility method so that failure will be handled/returned.
            get_result = lambda: (method := getattr(self, method_name)) and method(
                *self._convert_msgspec_args(method, args)
            )
            enqueue_output = lambda out: self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=out))
            )
            self._invoke_utility_method(method_name, get_result, output, enqueue_output)
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )

    @staticmethod
    def _invoke_utility_method(
        name: str, get_result: Callable, output: UtilityOutput, enqueue_output: Callable
    ):
        """Evaluate ``get_result`` and enqueue ``output`` with its result.

        If the result is a ``Future``, this re-invokes itself from the
        future's done-callback instead of enqueueing now. Any exception is
        logged and reported through ``output.failure_message``.
        """
        try:
            result = get_result()
            if isinstance(result, Future):
                # Defer utility output handling until future completion.
                callback = lambda future: EngineCoreProc._invoke_utility_method(
                    name, future.result, output, enqueue_output
                )
                result.add_done_callback(callback)
                return
            output.result = UtilityResult(result)
        except Exception as e:
            logger.exception("Invocation of %s method failed", name)
            output.failure_message = f"Call to {name} method failed: {str(e)}"
        enqueue_output(output)

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
        arg type, try converting to msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, arg_types)
        )

    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient."""

        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until msg sent by the daemon before shutdown.
        self.output_thread.join(timeout=5.0)
        if self.output_thread.is_alive():
            logger.fatal(
                "vLLM shutdown signal from EngineCore failed "
                "to send. Please report this issue."
            )

    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Connects to every input address (and the DP Coordinator, if any),
        sets ``ready_event`` once connected, then decodes incoming messages
        and pushes ``(request_type, request)`` tuples onto ``input_queue``.
        ADD requests are preprocessed here; ABORT requests are also pushed
        onto ``aborts_queue``.

        Raises:
            RuntimeError: if the DP Coordinator's first message is not READY.
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. The recv() must
                # not live inside an assert: under `python -O` asserts are
                # stripped, the READY message would never be consumed, and it
                # would later be misparsed as a request type.
                coord_msg = coord_socket.recv()
                if coord_msg != b"READY":
                    raise RuntimeError(
                        "Expected READY message from DP Coordinator, "
                        f"got {coord_msg!r}"
                    )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    request: Any
                    if request_type == EngineCoreRequestType.ADD:
                        req: EngineCoreRequest = add_request_decoder.decode(data_frames)
                        try:
                            request = self.preprocess_add_request(req)
                        except Exception:
                            self._handle_request_preproc_error(req)
                            continue
                    else:
                        request = generic_decoder.decode(data_frames)

                        if request_type == EngineCoreRequestType.ABORT:
                            # Aborts are added to *both* queues, allows us to eagerly
                            # process aborts while also ensuring ordering in the input
                            # queue to avoid leaking requests. This is ok because
                            # aborting in the scheduler is idempotent.
                            self.aborts_queue.put_nowait(request)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread.

        Takes items from ``output_queue`` and sends them to the client socket
        selected by ``client_index`` (or to the DP Coordinator when
        ``client_index == -1``). On ``ENGINE_CORE_DEAD`` it forwards the
        sentinel to every client socket and exits.
        """

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                )
                for output_path in output_paths
            ]
            coord_socket = (
                stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
                if coord_output_path is not None
                else None
            )
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with. Apply the same
                # max_reuse_bufs cap as below so a burst of completed sends
                # cannot pin an unbounded number of large buffers.
                while pending and pending[-1][0].done:
                    _, _, done_buffer = pending.pop()
                    if len(reuse_buffers) < max_reuse_bufs:
                        reuse_buffers.append(done_buffer)

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(
                    buffers, copy=False, track=True
                )
                if not tracker.done:
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)

    def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None:
        """Log and return a request-scoped error response for exceptions raised
        from the add request preprocessing in the input socket processing thread.
        """
        logger.exception(
            "Unexpected error pre-processing request %s", request.request_id
        )
        self.output_queue.put_nowait(
            (
                request.client_index,
                EngineCoreOutputs(
                    engine_index=self.engine_index,
                    finished_requests={request.request_id},
                    outputs=[
                        EngineCoreOutput(
                            request_id=request.request_id,
                            new_token_ids=[],
                            finish_reason=FinishReason.ERROR,
                        )
                    ],
                ),
            )
        )

    def pause_scheduler(
        self, mode: PauseMode = "abort", clear_cache: bool = True
    ) -> Future | None:
        """Pause generation; behavior depends on mode.

        All pause modes queue new adds -- "abort" and "keep" skip step();
        "wait" allows step() so in-flight requests can drain.

        - ``abort``: Set PAUSED_NEW, abort all requests, wait for abort
          outputs to be sent (when running with output_queue), optionally
          clear caches, then complete the returned Future.
        - ``wait``: Set PAUSED_NEW (queue adds, keep stepping); when drained,
          optionally clear caches, then complete the returned Future.
        - ``keep``: Set PAUSED_ALL; return a Future that completes when the
          output queue is empty.

        Returns ``None`` (after optionally clearing caches) when the engine
        already has no work.
        """
        if mode not in ("keep", "abort", "wait"):
            raise ValueError(f"Invalid pause mode: {mode}")

        def engine_idle_callback(engine: "EngineCoreProc", future: Future[Any]) -> None:
            if clear_cache:
                engine._reset_caches()
            future.set_result(None)

        if mode == "abort":
            aborted_reqs = self.scheduler.finish_requests(
                None, RequestStatus.FINISHED_ABORTED
            )
            self._send_abort_outputs(aborted_reqs)

        pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
        self.scheduler.set_pause_state(pause_state)
        if not self.has_work():
            if clear_cache:
                self._reset_caches()
            return None

        # Completed by _notify_idle_state_callbacks once has_work() is False.
        future = Future[Any]()
        self._idle_state_callbacks.append(partial(engine_idle_callback, future=future))
        return future

    def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None:
        """Enqueue one ABORT-finished EngineCoreOutputs per client for the
        given ``(request_id, client_index)`` pairs."""
        # TODO(nick) this will be moved inside the scheduler
        if aborted_reqs:
            # Map client_index to list of request_ids that belong to that client.
            by_client = defaultdict[int, set[str]](set)
            for req_id, client_index in aborted_reqs:
                by_client[client_index].add(req_id)
            for client_index, req_ids in by_client.items():
                outputs = [
                    EngineCoreOutput(req_id, [], finish_reason=FinishReason.ABORT)
                    for req_id in req_ids
                ]
                eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs)
                self.output_queue.put_nowait((client_index, eco))


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    # Only perform the "any DP rank still has unfinished requests" all-reduce
    # once every this many engine steps (see _has_global_unfinished_reqs).
    _FINISH_SYNC_INTERVAL = 32

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        assert vllm_config.model_config.is_moe, (
            "DPEngineCoreProc should only be used for MoE models"
        )

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        self.current_wave = 0
        # Last (request-count) tuple published; used to suppress duplicates.
        self.last_counts = (0, 0)

        # Initialize the engine. The DP rank doubles as the engine index.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            engine_index=dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Validate the DP ranks and create the stateless DP process group."""
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert local_dp_rank is not None
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine, then destroy the DP group if one is set."""
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request and reconcile the request's wave with ours.

        A request from a newer wave advances ``current_wave``. A request for
        an already-completed wave, arriving while engines are idle, makes
        this engine ask the front-end (client index -1) to start a new wave.
        """
        super().add_request(request, request_wave)
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

    def resume_scheduler(self):
        """Resume scheduling and wake the other DP engines if work is pending."""
        super().resume_scheduler()
        if not self.engines_running and self.scheduler.has_unfinished_requests():
            # Wake up other DP engines.
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(start_wave=self.current_wave))
            )

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle START_DP_WAVE locally; delegate other request types."""
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Publish scheduler request counts to client -1 when they change."""
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank may still have unfinished requests.

        Returns True without communicating on all but every
        ``_FINISH_SYNC_INTERVAL``-th call; on those calls it performs the
        cross-rank check via ``ParallelConfig.has_unfinished_dp``.
        """
        # Optimization - only perform finish-sync all-reduce every N steps.
        self.step_counter += 1
        if self.step_counter % self._FINISH_SYNC_INTERVAL != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild the DP process group for a DP resize.

        Updates the parallel config from ``reconfig_request``. Unless this
        rank is being shut down, it creates a new DP group and writes the
        resulting master port back into the request. It then forwards the
        request to the model executor. When the DP size grows, it also
        shares this engine's KV-cache memory budget with the new ranks and
        warms up the model.
        """
        # Destroy the current DP group exactly once. ``shutdown()`` also
        # destroys ``self.dp_group`` when it is set, so clear the attribute
        # before calling it to avoid destroying the same group twice (and a
        # third time in the SHUTDOWN_CURRENT_RANK path below).
        old_dp_group, self.dp_group = self.dp_group, None
        stateless_destroy_torch_distributed_process_group(old_dp_group)
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        # Propagate the master port back to the request; presumably group
        # init may have changed it — TODO confirm.
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )


class EngineCoreActorMixin:
    """
    Ray-actor plumbing for running an EngineCore process.

    Mixed into both the data-parallel MoE actor and the single-engine
    (non-MoE and/or non-DP) actor. It records the precomputed ZMQ addresses,
    which replace the handshake used by non-Ray processes. It also sets
    device visibility as early as possible, and provides the
    ``wait_for_init`` / ``run`` entry points invoked on the actor.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        addresses: EngineZmqAddresses,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Initialize tracer for distributed tracing if configured.
        maybe_init_worker_tracer(
            instrumenting_module_name="vllm.engine_core",
            process_kind="engine_core",
            process_name=f"DPEngineCoreActor_DP{dp_rank}",
        )

        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_index = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict device visibility to this DP rank's slice (non-XPU only)."""
        from vllm.platforms import current_platform

        # XPU is deliberately skipped here; presumably its device visibility
        # is handled elsewhere — TODO confirm.
        if current_platform.is_xpu():
            return

        self._set_cuda_visible_devices(
            vllm_config, local_dp_rank, current_platform.device_control_env_var
        )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Set ``device_control_env_var`` to this local DP rank's devices.

        The rank's devices are the local index range
        ``[local_dp_rank * world_size, (local_dp_rank + 1) * world_size)``,
        as reported in the error message below.

        Raises:
            RuntimeError: if that range cannot be resolved against the
                current value of the environment variable (wraps the
                underlying ``IndexError``).
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e
        os.environ[device_control_env_var] = value

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop, always shutting down on exit.

        ``SystemExit`` is logged at debug level and any other exception is
        logged with its traceback; both are re-raised.
        """
        try:
            self.run_busy_loop()  # type: ignore[attr-defined]
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()  # type: ignore[attr-defined]


class DPMoEEngineCoreActor(EngineCoreActorMixin, DPEngineCoreProc):
    """Ray actor hosting one data-parallel EngineCore rank of an MoE model."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Record the global DP rank on the config up front;
        # DPEngineCoreProc.__init__ reads it back as its engine index.
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_rank = dp_rank

        EngineCoreActorMixin.__init__(
            self,
            vllm_config,
            addresses,
            dp_rank=dp_rank,
            local_dp_rank=local_dp_rank,
        )
        # Ray actors skip the ZMQ handshake (addresses are known up front),
        # so no handshake address is supplied.
        DPEngineCoreProc.__init__(
            self,
            vllm_config,
            local_client,
            handshake_address="",
            executor_class=executor_class,
            log_stats=log_stats,
        )


class EngineCoreActor(EngineCoreActorMixin, EngineCoreProc):
    """Ray actor hosting a standalone EngineCore (non-MoE and/or non-DP)."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # The engine itself is configured as a lone DP rank (size 1, rank 0);
        # the actor's dp_rank is still passed on as the engine index below.
        pc = vllm_config.parallel_config
        pc.data_parallel_size = 1
        pc.data_parallel_size_local = 1
        pc.data_parallel_rank = 0

        EngineCoreActorMixin.__init__(
            self,
            vllm_config,
            addresses,
            dp_rank=dp_rank,
            local_dp_rank=local_dp_rank,
        )
        # No handshake address: Ray actors receive their addresses up front.
        EngineCoreProc.__init__(
            self,
            vllm_config,
            local_client,
            "",
            executor_class,
            log_stats,
            engine_index=dp_rank,
        )