core.py 71.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
import queue
5
import signal
6
7
import threading
import time
8
from collections import defaultdict, deque
9
from collections.abc import Callable, Generator
10
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
11
from contextlib import ExitStack, contextmanager
12
from inspect import isclass, signature
13
from logging import DEBUG
14
from typing import Any, TypeVar, cast
15

16
import msgspec
17
18
import zmq

19
20
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
21
from vllm.envs import enable_envs_cache
22
from vllm.logger import init_logger
23
from vllm.logging_utils.dump_input import dump_engine_exception
24
from vllm.lora.request import LoRARequest
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.tasks import POOLING_TASKS, SupportedTask
27
from vllm.tracing import instrument, maybe_init_worker_tracer
28
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
29
30
31
32
from vllm.utils.gc_utils import (
    freeze_gc_heap,
    maybe_attach_gc_debug_callback,
)
33
from vllm.utils.hashing import get_hash_fn_by_name
34
from vllm.utils.network_utils import make_zmq_socket
35
from vllm.utils.system_utils import decorate_logs, set_process_title
36
37
38
39
40
41
42
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
43
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
44
from vllm.v1.core.sched.output import SchedulerOutput
45
from vllm.v1.engine import (
46
    EngineCoreOutput,
47
48
49
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
50
    FinishReason,
51
    PauseMode,
52
53
54
55
56
57
58
59
60
61
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
62
from vllm.v1.executor import Executor
63
from vllm.v1.kv_cache_interface import KVCacheConfig
64
from vllm.v1.metrics.stats import SchedulerStats
65
from vllm.v1.outputs import ModelRunnerOutput
66
from vllm.v1.request import Request, RequestStatus
67
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
68
from vllm.v1.structured_output import StructuredOutputManager
69
from vllm.v1.utils import compute_iteration_details
70
71
72
73
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

74
POLLING_TIMEOUT_S = 2.5
75
HANDSHAKE_TIMEOUT_MINS = 5
76

77
_R = TypeVar("_R")  # Return type for collective_rpc
78

79
80
81
82

class EngineCore:
    """Inner loop of vLLM's Engine."""

83
84
85
86
87
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
        include_finished_set: bool = False,
    ):
        """Build the model executor, KV caches and scheduler for one engine core.

        Args:
            vllm_config: Engine configuration. It is mutated in place:
                ``cache_config.block_size``, ``num_gpu_blocks`` and
                ``num_cpu_blocks`` are filled in after KV cache initialization,
                and chunked prefill is disabled for models without a KV cache.
            executor_class: Executor type; instantiated with ``vllm_config``.
            log_stats: Stored on the engine and forwarded to the scheduler.
            executor_fail_callback: Optional callback registered with the
                executor via ``register_failure_callback``.
            include_finished_set: Forwarded to the scheduler unchanged.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        # Only log the full config once per node (local DP rank 0 / no DP).
        if not vllm_config.parallel_config.data_parallel_rank_local:
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        # Sentinel; _initialize_kv_caches() overwrites it when the model has a
        # KV cache.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )
        # The engine-wide block size is the smallest block size across all KV
        # cache groups.
        if kv_cache_config.kv_cache_groups:
            vllm_config.cache_config.block_size = min(
                g.kv_cache_spec.block_size for g in kv_cache_config.kv_cache_groups
            )
        elif vllm_config.cache_config.block_size is None:
            # Attention-free models (encoder-only, SSM) — use default.
            vllm_config.cache_config.block_size = 16
        vllm_config.validate_block_size()

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        # Propagate the final block counts to the workers.
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        Scheduler = vllm_config.scheduler_config.get_scheduler_cls()

        if len(kv_cache_config.kv_cache_groups) == 0:  # noqa: SIM102
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            if vllm_config.scheduler_config.enable_chunked_prefill:
                logger.warning("Disabling chunked prefill for model without KVCache")
                vllm_config.scheduler_config.enable_chunked_prefill = False

        # The scheduler's block size is the cache block size scaled by the
        # decode and prefill context-parallel sizes (presumably because one
        # scheduler block spans all context-parallel ranks — confirm).
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
            * vllm_config.parallel_config.prefill_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=include_finished_set,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None
        # `connector` is presumably not declared on SchedulerInterface, hence
        # the type: ignore comments.
        if self.scheduler.connector is not None:  # type: ignore
            self.model_executor.init_kv_output_aggregator(self.scheduler.connector)  # type: ignore

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        # May be None; preprocess_add_request() and reset_mm_cache() check it.
        self.mm_receiver_cache = mm_registry.engine_receiver_cache_from_config(
            vllm_config
        )

        # If a KV connector is initialized for scheduler, we want to collect
        # handshake metadata from all workers so the connector in the scheduler
        # will have the full context
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            # Collect and store KV connector xfer metadata from workers
            # (after KV cache registration)
            xfer_handshake_metadata = (
                self.model_executor.get_kv_connector_handshake_metadata()
            )

            if xfer_handshake_metadata:
                # xfer_handshake_metadata is list of dicts from workers
                # Each dict already has structure {tp_rank: metadata}
                # Merge all worker dicts into a single dict
                content: dict[int, Any] = {}
                for worker_dict in xfer_handshake_metadata:
                    if worker_dict is not None:
                        content.update(worker_dict)
                kv_connector.set_xfer_handshake_metadata(content)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput, Future[Any]]] | None
        ) = None
        # The queue only exists when the executor can run more than one batch
        # concurrently; otherwise batch_queue stays None and step() is used.
        if self.batch_queue_size > 1:
            logger.debug("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        self.is_ec_producer = (
            vllm_config.ec_transfer_config is not None
            and vllm_config.ec_transfer_config.is_ec_producer
        )
        self.is_pooling_model = vllm_config.model_config.runner_type == "pooling"

        # Per-request block hashing is needed both for prefix caching and
        # whenever a KV connector is present.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        # Use the pipelined step only when the batch queue was enabled above.
        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )
        self.async_scheduling = vllm_config.scheduler_config.async_scheduling

        # Lists of request ids to abort; drained by _process_aborts_queue()
        # before each scheduler output update.
        self.aborts_queue = queue.Queue[list[str]]()

        # Registered per-step callables (presumably invoked by the busy loop in
        # a subclass — confirm).
        self.per_step_hooks: set[Callable] = set()

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()
        # If enable, attach GC debugger after static variable freeze.
        maybe_attach_gc_debug_callback()
        # Enable environment variable cache (e.g. assume no more
        # environment variable overrides after this point)
        enable_envs_cache()

    @instrument(span_name="Prepare model")
    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Size, create and warm up the KV caches on the workers.

        Queries the executor for KV cache specs and determines how much device
        memory can be devoted to KV cache. That budget is either profiled or,
        for an elastic-EP scale-up launch, synced over the DP group. The method
        then builds the KV cache configs from the budget and asks the executor
        to allocate the caches and warm up the model.

        Args:
            vllm_config: Engine config. ``get_kv_cache_configs`` may lower
                ``model_config.max_model_len`` (auto-fit). If so, the new value
                is pushed to the workers via ``update_max_model_len``.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``.
            ``num_cpu_blocks`` is always 0 on this code path.
        """
        # Use a monotonic clock for the duration measurement so wall-clock
        # adjustments (NTP, manual changes) cannot skew the reported time.
        start = time.monotonic()

        # Get all kv cache needed by the model. The list is paired 1:1 with
        # available_gpu_memory below (presumably one entry per worker).
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Attention-free models report only empty specs.
        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Elastic EP scale-up: obtain the KV cache memory budget by
                # syncing over the DP group (passing -1 as this engine's own
                # value) instead of profiling.
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        # Track max_model_len before KV cache config to detect auto-fit changes
        max_model_len_before = vllm_config.model_config.max_model_len

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )

        # If auto-fit reduced max_model_len, sync the new value to workers.
        # This is needed because workers were spawned before memory profiling
        # and have the original (larger) max_model_len cached.
        max_model_len_after = vllm_config.model_config.max_model_len
        if max_model_len_after != max_model_len_before:
            self.collective_rpc("update_max_model_len", args=(max_model_len_after,))

        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        # No CPU KV blocks are allocated on this code path.
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.monotonic() - start
        logger.info_once(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
            scope="local",
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model executor reports as supported."""
        tasks = self.model_executor.supported_tasks
        return tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Validate a preprocessed request and hand it to the scheduler.

        Args:
            request: The request to schedule.
            request_wave: Which wave of requests this one is expected to belong
                to in the data-parallel (DP) case. Not used by this base
                implementation.

        Raises:
            TypeError: If ``request.request_id`` is not a ``str``.
            ValueError: If the request carries pooling params whose task is not
                one of the pooling tasks supported by the model.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                # Separate the two sentences; previously they ran together as
                # "Unsupported task: 'x' Supported tasks: [...]".
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r}. "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        if request.kv_transfer_params is not None and (
            not self.scheduler.get_kv_connector()
        ):
            # The params are not stripped here; presumably the scheduler ignores
            # kv_transfer_params when no connector exists — confirm.
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort the given requests in the scheduler.

        Each request is finished with status ``RequestStatus.FINISHED_ABORTED``.
        """
        # TODO: The scheduler does not really need the precise finish reason;
        # TBD whether we propagate it (client-aborted vs. stop criteria met).
        aborted_status = RequestStatus.FINISHED_ABORTED
        self.scheduler.finish_requests(request_ids, aborted_status)

    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Context manager that dumps engine state if the wrapped code fails.

        If the ``with`` body raises an ``Exception``, the config, the given
        scheduler output and the current scheduler stats are passed to
        ``dump_engine_exception``. The exception is then re-raised unchanged.

        Args:
            scheduler_output: The batch being executed; included in the dump.
        """
        try:
            yield
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: The dump call is exception-free, so the original error is
            # what propagates.
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # A bare ``raise`` re-raises the active exception with its original
            # traceback, without adding an extra re-raise entry for this frame.
            raise

    @contextmanager
    def log_iteration_details(self, scheduler_output: SchedulerOutput):
        """Log the batch composition and wall time of the wrapped iteration.

        This is a no-op unless
        ``observability_config.enable_logging_iteration_details`` is set.
        When enabled, it logs one INFO line per completed iteration. The line
        has the context/generation request and token counts derived from
        ``scheduler_output`` and the elapsed time of the ``with`` body. A
        per-engine iteration counter is then incremented. Nothing is logged if
        the wrapped block raises.
        """
        if not self.vllm_config.observability_config.enable_logging_iteration_details:
            yield
            return
        # The counter is created lazily the first time logging is enabled.
        self._iteration_index = getattr(self, "_iteration_index", 0)
        iteration_details = compute_iteration_details(scheduler_output)
        before = time.monotonic()
        yield
        elapsed_ms = (time.monotonic() - before) * 1000
        # Lazy %-style args produce the same text as before, but formatting is
        # left to the logging framework instead of being done eagerly.
        logger.info(
            "Iteration(%s): %s context requests, %s context tokens, "
            "%s generation requests, %s generation tokens, "
            "iteration elapsed time: %.2f ms",
            self._iteration_index,
            iteration_details.num_ctx_requests,
            iteration_details.num_ctx_tokens,
            iteration_details.num_generation_requests,
            iteration_details.num_generation_tokens,
            elapsed_ms,
        )
        self._iteration_index += 1

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Run one synchronous engine iteration.

        The method schedules a batch and launches model execution. It waits
        for the result, requesting sampling separately when execution returns
        no output. It then applies pending aborts and updates the scheduler
        with the model output.

        Returns:
            ``(outputs, model_executed)``. ``outputs`` is what
            ``scheduler.update_from_output`` produced (``{}`` when there was
            nothing to schedule). ``model_executed`` is True iff the batch
            had scheduled tokens.
        """
        # Bail out when the scheduler holds neither unfinished requests nor
        # finished ones that are still awaiting removal from the batch.
        if not self.scheduler.has_requests():
            return {}, False

        sched_out = self.scheduler.schedule()
        # Launch execution first; the grammar bitmask is then computed while
        # the model presumably runs in the background.
        exec_future = self.model_executor.execute_model(sched_out, non_block=True)
        bitmask = self.scheduler.get_grammar_bitmask(sched_out)

        with (
            self.log_error_detail(sched_out),
            self.log_iteration_details(sched_out),
        ):
            result = exec_future.result()
            # A None result means sampling must be requested as a separate call.
            if result is None:
                result = self.model_executor.sample_tokens(bitmask)

        # Aborts that arrived while the model was running are applied before
        # its output is processed.
        self._process_aborts_queue()

        outputs = self.scheduler.update_from_output(sched_out, result)
        executed = sched_out.total_num_scheduled_tokens > 0
        return outputs, executed

    def post_step(self, model_executed: bool) -> None:
        """Pass newly drafted speculative tokens to the scheduler.

        This only applies to synchronous scheduling with speculative decoding,
        after a step that actually ran the model. With async scheduling the
        draft token ids cannot be fetched in advance; the worker process
        updates them instead.
        """
        needs_drafts = self.use_spec_decode and model_executed
        if self.async_scheduling or not needs_drafts:
            return
        drafts = self.model_executor.take_draft_token_ids()
        if drafts is None:
            return
        self.scheduler.update_draft_token_ids(drafts)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches using the batch queue (pipelining).

        New batches are pushed on the left of ``batch_queue`` and popped from
        the right, so the right end holds the oldest in-flight batch.

        Flow:
        1. If the scheduler has requests, schedule a new batch and submit it
           without blocking. If that batch ran the model, the queue still has
           room, and the oldest batch is not done yet, return ``(None, True)``
           right away. Filling the queue takes priority over collecting
           outputs.
        2. Otherwise block on the oldest batch in the queue.
        3. Update the scheduler from that batch's output.
        4. If sampling for the newly scheduled batch had to be deferred (it was
           waiting on structured-output tokens from the prior step), compute
           its grammar bitmask now and enqueue its sampling future.

        Returns:
            ``(engine_core_outputs, model_executed)``. ``engine_core_outputs``
            is None when nothing is output in this step.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        deferred_scheduler_output = None
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            exec_future = self.model_executor.execute_model(
                scheduler_output, non_block=True
            )
            # EC producers never report model_executed, so they always take
            # the no-sampling path below.
            if not self.is_ec_producer:
                model_executed = scheduler_output.total_num_scheduled_tokens > 0

            if self.is_pooling_model or not model_executed:
                # No sampling required (no requests scheduled).
                future = cast(Future[ModelRunnerOutput], exec_future)
            else:
                if not scheduler_output.pending_structured_output_tokens:
                    # We aren't waiting for any tokens, get any grammar output
                    # and sample immediately.
                    grammar_output = self.scheduler.get_grammar_bitmask(
                        scheduler_output
                    )
                    future = self.model_executor.sample_tokens(
                        grammar_output, non_block=True
                    )
                else:
                    # We need to defer sampling until we have processed the model output
                    # from the prior step.
                    deferred_scheduler_output = scheduler_output

            # A deferred batch is NOT enqueued here; it is enqueued at the end
            # of this method, once the prior batch's output has been processed.
            if not deferred_scheduler_output:
                # Add this step's future to the queue.
                batch_queue.appendleft((future, scheduler_output, exec_future))
                # batch_queue[-1] is the oldest in-flight batch.
                if (
                    model_executed
                    and len(batch_queue) < self.batch_queue_size
                    and not batch_queue[-1][0].done()
                ):
                    # Don't block on next worker response unless the queue is full
                    # or there are no more requests to schedule.
                    return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output, exec_model_fut = batch_queue.pop()
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
            model_output = future.result()
            if model_output is None:
                # None from sample_tokens() implies that the original execute_model()
                # call failed - raise that exception.
                exec_model_fut.result()
                raise RuntimeError("unexpected error")

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        # NOTE(nick): We can either handle the deferred tasks here or save
        # in a field and do it immediately once step_with_batch_queue is
        # re-called. The latter slightly favors TTFT over TPOT/throughput.
        if deferred_scheduler_output:
            # If we are doing speculative decoding with structured output,
            # we need to get the draft token ids from the prior step before
            # we can compute the grammar bitmask for the deferred request.
            if self.use_spec_decode:
                draft_token_ids = self.model_executor.take_draft_token_ids()
                assert draft_token_ids is not None
                # Update the draft token ids in the scheduler output to
                # filter out the invalid spec tokens, which will be padded
                # with -1 and skipped by the grammar bitmask computation.
                self.scheduler.update_draft_token_ids_in_output(
                    draft_token_ids, deferred_scheduler_output
                )
            # We now have the tokens needed to compute the bitmask for the
            # deferred request. Get the bitmask and call sample tokens.
            grammar_output = self.scheduler.get_grammar_bitmask(
                deferred_scheduler_output
            )
            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
            # exec_future is bound: deferred_scheduler_output is only set in
            # the has_requests() branch, which also assigns exec_future.
            batch_queue.appendleft((future, deferred_scheduler_output, exec_future))

        return engine_core_outputs, model_executed

    def _process_aborts_queue(self):
        """Drain all queued abort requests and abort them in one batch.

        ``abort_requests`` is called only if at least one item was dequeued.
        """
        pending = self.aborts_queue
        if pending.empty():
            return
        collected = []
        while not pending.empty():
            item = pending.get_nowait()
            # Items should be lists of ids; tolerate a bare string id too.
            if isinstance(item, str):
                collected.append(item)
            else:
                collected.extend(item)
        # A single batched abort is cheaper than one call per queued item.
        self.abort_requests(collected)

    def shutdown(self):
        """Tear down the structured-output backend, executor and scheduler.

        This tolerates a partially constructed engine. For example, if
        ``__init__`` raised after creating the executor but before creating
        the scheduler, components that were never assigned are skipped
        instead of raising ``AttributeError``, so the executor is still shut
        down.
        """
        structured_output_manager = getattr(self, "structured_output_manager", None)
        if structured_output_manager is not None:
            structured_output_manager.clear_backend()

        model_executor = getattr(self, "model_executor", None)
        if model_executor is not None:
            model_executor.shutdown()

        scheduler = getattr(self, "scheduler", None)
        if scheduler is not None:
            scheduler.shutdown()

    def profile(self, is_start: bool = True, profile_prefix: str | None = None):
        """Forward a profiling start/stop request to the model executor.

        Args:
            is_start: Selects start vs. stop, as interpreted by the executor.
            profile_prefix: Optional prefix passed through unchanged.
        """
        executor = self.model_executor
        executor.profile(is_start, profile_prefix)

    def reset_mm_cache(self):
        """Clear the multi-modal caches held by this engine and the executor.

        A warning is logged if requests are still in flight, because the caches
        are not re-synchronized afterwards.
        """
        # NOTE: This is mainly a debugging aid, so the internal caches
        # (P0 sender, P1 receiver) are intentionally not re-synced.
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The receiver cache lives either here in EngineCore or in
        # WorkerWrapperBase; only the local one is cleared here, if present.
        receiver_cache = self.mm_receiver_cache
        if receiver_cache is not None:
            receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the scheduler's prefix cache.

        Args:
            reset_running_requests: Passed through to the scheduler.
            reset_connector: Passed through to the scheduler.

        Returns:
            The scheduler's boolean result (presumably whether the reset
            succeeded — confirm against the scheduler implementation).
        """
        scheduler = self.scheduler
        result = scheduler.reset_prefix_cache(
            reset_running_requests, reset_connector
        )
        return result

    def reset_encoder_cache(self) -> None:
        """Reset the encoder cache to invalidate all cached encoder outputs.

        This should be called when model weights are updated, so that stale
        encoder outputs (e.g. vision embeddings) computed with the old weights
        are not reused. It clears both the scheduler's encoder cache manager
        (logical state) and the model executor's encoder cache (physical
        storage in the GPU model runner).
        """
        # Clearing entries that in-flight requests may still rely on can desync
        # internal caches, so warn when requests are in progress.
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the encoder cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # Reset the scheduler's encoder cache manager (logical state)
        self.scheduler.reset_encoder_cache()
        # Reset the model runner's encoder cache (physical storage)
        self.model_executor.reset_encoder_cache()

    def pause_scheduler(
        self, mode: PauseMode = "abort", clear_cache: bool = True
    ) -> Future[Any] | None:
        """Pause scheduling.

        The base ``EngineCore`` implementation is a no-op that ignores ``mode``
        and ``clear_cache`` and returns ``None``. ``EngineCoreProc`` overrides
        it.
        """
        result: Future[Any] | None = None
        return result

    def resume_scheduler(self) -> None:
        """Resume scheduling.

        This does nothing in the base ``EngineCore``; ``EngineCoreProc``
        overrides it.
        """
        return None

    def is_scheduler_paused(self) -> bool:
        """Report whether the scheduler is in any pause state.

        Always ``False`` for the base ``EngineCore``; ``EngineCoreProc``
        overrides this.
        """
        paused = False
        return paused

    def sleep(self, level: int = 1):
        """Put the engine to sleep at the given level.

        Args:
            level: Sleep level.
                - 0: Pause scheduling only, via ``pause_scheduler()`` with its
                  default arguments. Requests are still accepted but not
                  processed; the executor is not touched.
                - 1: Offload model weights to CPU, discard KV cache.
                - 2: Discard all GPU memory.
                Any non-zero level is passed to the executor unchanged.
        """
        if level != 0:
            # Levels 1+ manage GPU memory, which is the executor's job.
            self.model_executor.sleep(level)
            return
        # Level 0: pause scheduling only; any returned future is ignored.
        # NOTE(review): pause_scheduler() defaults to mode="abort",
        # clear_cache=True — confirm that matches the intended level-0
        # semantics in the EngineCoreProc override.
        self.pause_scheduler()

    def wake_up(self, tags: list[str] | None = None):
        """Wake the engine up from sleep.

        Scheduling is always resumed. What is passed to the executor depends
        on ``tags``:

        - ``None``, or a list without ``"scheduling"``: ``tags`` is forwarded
          to ``model_executor.wake_up`` unchanged (full wake-up).
        - A list containing ``"scheduling"`` (level-0 wake-up): the remaining
          tags are forwarded if there are any. If ``"scheduling"`` was the
          only tag, the executor is not called.

        Args:
            tags: Tags to wake up. Use ``["scheduling"]`` for a level-0 wake-up.
        """
        self.resume_scheduler()
        if tags is None or "scheduling" not in tags:
            self.model_executor.wake_up(tags)
            return
        leftover = [tag for tag in tags if tag != "scheduling"]
        if leftover:
            self.model_executor.wake_up(leftover)

    def is_sleeping(self) -> bool:
        """Return whether the engine is asleep at any level.

        That is, either scheduling is paused (level 0) or the executor reports
        that it is sleeping.
        """
        paused = self.is_scheduler_paused()
        if paused:
            return paused
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Ask the model executor to run a dummy batch."""
        executor = self.model_executor
        executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``add_lora`` to the executor and return its result."""
        executor = self.model_executor
        return executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``remove_lora`` to the executor and return its result."""
        executor = self.model_executor
        return executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the executor."""
        executor = self.model_executor
        return executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Forward ``pin_lora`` to the executor and return its result."""
        executor = self.model_executor
        return executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Have the model executor save the model's sharded state.

        All arguments are passed to the executor as keyword arguments.

        Args:
            path: Destination path.
            pattern: Optional pattern, forwarded unchanged.
            max_size: Optional size limit, forwarded unchanged.
        """
        forwarded = {"path": path, "pattern": pattern, "max_size": max_size}
        self.model_executor.save_sharded_state(**forwarded)

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Forward a collective RPC to the model executor.

        Args:
            method: Name of the method to call, or a callable to run.
            timeout: Optional timeout, forwarded unchanged.
            args: Positional arguments for ``method``.
            kwargs: Keyword arguments for ``method``.

        Returns:
            The list of results returned by the executor.
        """
        executor = self.model_executor
        return executor.collective_rpc(method, timeout, args, kwargs)

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Turn an ``EngineCoreRequest`` into a scheduler ``Request``.

        This may be called directly from the input-processing thread, so that
        request initialization overlaps with the model forward pass.

        Returns:
            ``(request, wave)`` where ``wave`` is ``request.current_wave``.
        """
        # Thread safety: no race. `mm_receiver_cache` is reset at the end of
        # LLMEngine init and afterwards is only accessed from the input
        # processing thread.
        cache = self.mm_receiver_cache
        if cache is not None and request.mm_features:
            request.mm_features = cache.get_and_update_features(request.mm_features)

        prepared = Request.from_engine_core_request(request, self.request_block_hasher)

        if prepared.use_structured_output:
            # Thread safety: no race. `grammar_init` is only invoked from the
            # input processing thread, each request is independent and grammar
            # compilation is async; the scheduler always checks compilation
            # status before scheduling a request.
            self.structured_output_manager.grammar_init(prepared)
        return prepared, request.current_wave


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    # Sentinel pushed onto the output queue (and forwarded to every client
    # output socket) to signal that this engine core has died.
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"

    @instrument(span_name="EngineCoreProc init")
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        *,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, build the engine, start IO threads.

        The whole initialization runs inside the handshake context. The READY
        message is therefore only sent back to the front-end after the engine
        is constructed and the input thread has signalled readiness.
        """
        # Queues between the socket IO threads and the core busy loop.
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()

        # Executor failures are turned into an input request so that the busy
        # loop raises from its own thread (see _handle_client_request).
        executor_fail_callback = lambda: self.input_queue.put_nowait(
            (EngineCoreRequestType.EXECUTOR_FAILED, b"")
        )

        self.engine_index = engine_index
        # 2-byte little-endian ZMQ identity derived from the engine index.
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        ) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            internal_dp_balancing = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )
            # Only publish request queue stats to coordinator for "internal"
            # and "hybrid" LB modes.
            self.publish_dp_lb_stats = internal_dp_balancing

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config,
                executor_class,
                log_stats,
                executor_fail_callback,
                internal_dp_balancing,
            )

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    identity,
                    ready_event,
                ),
                daemon=True,
            )
            input_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(
                    addresses.outputs,
                    addresses.coordinator_output,
                    self.engine_index,
                ),
                daemon=True,
            )
            self.output_thread.start()

            # Don't complete handshake until DP coordinator ready message is
            # received.
            while not ready_event.wait(timeout=10):
                if not input_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                # Without a coordinator the input thread sets the event
                # immediately, so a timeout implies we are waiting on it.
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.

        After the handshake context(s) exit (which sends READY), the config's
        ``__post_init__`` is re-run because the handshake may have updated
        ``parallel_config``.
        """
        input_ctx = zmq.Context()
        try:
            is_local = local_client and client_handshake_address is None
            headless = not local_client
            handshake = self._perform_handshake(
                input_ctx,
                handshake_address,
                identity,
                is_local,
                headless,
                vllm_config,
                vllm_config.parallel_config,
            )
            if client_handshake_address is None:
                with handshake as addresses:
                    yield addresses
            else:
                assert local_client
                local_handshake = self._perform_handshake(
                    input_ctx, client_handshake_address, identity, True, False, vllm_config
                )
                with handshake as addresses, local_handshake as client_addresses:
                    # Socket addresses come from the colocated front-end;
                    # everything else comes from the primary handshake.
                    addresses.inputs = client_addresses.inputs
                    addresses.outputs = client_addresses.outputs
                    yield addresses

            # Update config which may have changed from the handshake
            vllm_config.__post_init__()
        finally:
            # Release the handshake context deterministically instead of
            # leaving it to garbage collection. destroy() closes any sockets
            # still open (using their own linger) and then terminates the
            # context.
            input_ctx.destroy()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run a single HELLO / init / READY handshake on a DEALER socket.

        Yields the addresses received from the front-end. When the caller's
        ``with`` body completes without raising, a READY message is sent back.
        The READY message carries the number of GPU blocks, the stats publish
        address and, when DP > 1, the parallel config hash.
        """
        with make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        ) as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(
                handshake_socket, local_client, headless, parallel_config_to_update
            )
            yield addresses

            # Send ready message.
            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
            # We pass back the coordinator stats update address here for the
            # external LB case for our colocated front-end to use (coordinator
            # only runs with rank 0).
            dp_stats_address = self.frontend_stats_publish_address

            # Include config hash for DP configuration validation
            ready_msg = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": num_gpu_blocks,
                "dp_stats_address": dp_stats_address,
            }
            if vllm_config.parallel_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = (
                    vllm_config.parallel_config.compute_hash()
                )

            handshake_socket.send(msgspec.msgpack.encode(ready_msg))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Send HELLO and wait for the front-end's init message.

        If ``parallel_config`` is given, each entry of the received
        ``init_message.parallel_config`` mapping is applied to it via
        ``setattr``.

        Returns:
            The addresses carried in the init message.

        Raises:
            RuntimeError: if no init message arrives within
                ``HANDSHAKE_TIMEOUT_MINS`` minutes.
        """
        # Send registration message.
        handshake_socket.send(
            msgspec.msgpack.encode(
                {
                    "status": "HELLO",
                    "local": local_client,
                    "headless": headless,
                }
            )
        )

        # Receive initialization message.
        logger.debug("Waiting for init message from front-end.")
        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                f"minutes"
            )
        init_bytes = handshake_socket.recv()
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            init_bytes, type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for key, value in init_message.parallel_config.items():
                setattr(parallel_config, key, value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process.

        Chooses ``DPEngineCoreProc`` for data-parallel MoE models and plain
        ``EngineCoreProc`` otherwise. Non-MoE DP ranks are configured as
        independent DP=1 engines. The busy loop then runs until a signal or an
        error.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: EngineCoreProc | None = None
        try:
            vllm_config: VllmConfig = kwargs["vllm_config"]
            parallel_config: ParallelConfig = vllm_config.parallel_config
            data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
            if data_parallel:
                parallel_config.data_parallel_rank_local = local_dp_rank
                maybe_init_worker_tracer(
                    instrumenting_module_name="vllm.engine_core",
                    process_kind="engine_core",
                    process_name=f"EngineCore_DP{dp_rank}",
                )
                set_process_title("EngineCore", f"DP{dp_rank}")
            else:
                maybe_init_worker_tracer(
                    instrumenting_module_name="vllm.engine_core",
                    process_kind="engine_core",
                    process_name="EngineCore",
                )
                set_process_title("EngineCore")
            decorate_logs()

            if data_parallel and vllm_config.kv_transfer_config is not None:
                # modify the engine_id and append the local_dp_rank to it to ensure
                # that the kv_transfer_config is unique for each DP rank.
                vllm_config.kv_transfer_config.engine_id = (
                    f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
                )
                logger.debug(
                    "Setting kv_transfer_config.engine_id to %s",
                    vllm_config.kv_transfer_config.engine_id,
                )

            parallel_config.data_parallel_index = dp_rank
            if data_parallel and vllm_config.model_config.is_moe:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                # Non-MoE DP ranks are completely independent, so treat like DP=1.
                # Note that parallel_config.data_parallel_index will still reflect
                # the original DP rank.
                parallel_config.data_parallel_size = 1
                parallel_config.data_parallel_size_local = 1
                parallel_config.data_parallel_rank = 0
                engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)

            assert engine_core is not None
            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                # Notify clients before the process goes down.
                engine_core._send_engine_dead()
            # Bare raise re-raises the active exception unchanged.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Hook for data-parallel setup; no-op for a non-DP engine core."""
        pass

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()
            # 3) Run any per-step hooks.
            self._process_per_step_hooks()

    def _process_input_queue(self):
        """Handle client requests; return once an engine step is warranted.

        Blocks on the input queue while there is nothing to step: no running
        DP engines, no scheduler requests, no batch queue and no per-step
        hooks. Then drains any remaining queued requests without blocking.
        """

        waited = False
        while (
            not self.engines_running
            and not self.scheduler.has_requests()
            and not self.batch_queue
            and not self.per_step_hooks
        ):
            if self.input_queue.empty():
                # Drain aborts queue; all aborts are also processed via input_queue.
                with self.aborts_queue.mutex:
                    self.aborts_queue.queue.clear()
                if logger.isEnabledFor(DEBUG):
                    logger.debug("EngineCore waiting for work.")
                    waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self) -> bool:
        """Called only when there are unfinished local requests.

        Runs one ``step_fn`` call, enqueues its per-client outputs and calls
        ``post_step``.

        Returns:
            Whether the model was executed in this step.
        """

        # Step the engine core.
        outputs, model_executed = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        for output in outputs.items() if outputs else ():
            self.output_queue.put_nowait(output)
        # Post-step hook.
        self.post_step(model_executed)

        # If no model execution happened but there are waiting requests
        # (e.g., WAITING_FOR_REMOTE_KVS), yield the GIL briefly to allow
        # background threads (like NIXL handshake) to make progress.
        # Without this, the tight polling loop can starve background threads.
        if not model_executed and self.scheduler.has_unfinished_requests():
            time.sleep(0.001)

        return model_executed

    def _process_per_step_hooks(self) -> None:
        """Invoke each registered per-step hook; drop those reporting done."""
        if self.per_step_hooks:
            # Iterate over a snapshot since finished hooks are removed.
            for hook in list(self.per_step_hooks):
                finished = hook(self)
                if finished:
                    self.per_step_hooks.discard(hook)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch request from client.

        Raises:
            RuntimeError: on an ``EXECUTOR_FAILED`` request.
        """

        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            # Lazily look-up utility method so that failure will be handled/returned.
            get_result = lambda: (method := getattr(self, method_name)) and method(
                *self._convert_msgspec_args(method, args)
            )
            enqueue_output = lambda out: self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=out))
            )
            self._invoke_utility_method(method_name, get_result, output, enqueue_output)
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )

    @staticmethod
    def _invoke_utility_method(
        name: str, get_result: Callable, output: UtilityOutput, enqueue_output: Callable
    ):
        """Evaluate a utility call and enqueue its result or failure message.

        If ``get_result`` returns a ``Future``, enqueuing is deferred to a
        done-callback that re-enters this method with ``future.result``.
        """
        try:
            result = get_result()
            if isinstance(result, Future):
                # Defer utility output handling until future completion.
                callback = lambda future: EngineCoreProc._invoke_utility_method(
                    name, future.result, output, enqueue_output
                )
                result.add_done_callback(callback)
                return
            output.result = UtilityResult(result)
        except Exception as e:
            logger.exception("Invocation of %s method failed", name)
            output.failure_message = f"Call to {name} method failed: {str(e)}"
        enqueue_output(output)

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
        arg type, try converting to msgspec object.

        Only parameters whose annotation is a ``msgspec.Struct`` subclass are
        converted; all other args are passed through unchanged.
        """
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, arg_types)
        )

    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient."""

        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until msg sent by the daemon before shutdown.
        self.output_thread.join(timeout=5.0)
        if self.output_thread.is_alive():
            logger.fatal(
                "vLLM shutdown signal from EngineCore failed "
                "to send. Please report this issue."
            )

    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Connects DEALER sockets to each input address (plus an XSUB socket to
        the DP coordinator if given). Sets ``ready_event`` once connected and,
        when a coordinator is present, once its READY message has been
        received. Then decodes incoming requests into ``input_queue`` forever.
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. The recv() must not
                # be inside an `assert`: under `python -O` asserts are
                # stripped, so the READY frame would not be consumed here and
                # would instead be read by the request loop below.
                coord_ready = coord_socket.recv()
                if coord_ready != b"READY":
                    raise RuntimeError(
                        f"Expected READY from DP Coordinator, got {coord_ready!r}"
                    )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    request: Any
                    if request_type == EngineCoreRequestType.ADD:
                        req: EngineCoreRequest = add_request_decoder.decode(data_frames)
                        try:
                            request = self.preprocess_add_request(req)
                        except Exception:
                            self._handle_request_preproc_error(req)
                            continue
                    else:
                        request = generic_decoder.decode(data_frames)

                        if request_type == EngineCoreRequestType.ABORT:
                            # Aborts are added to *both* queues, allows us to eagerly
                            # process aborts while also ensuring ordering in the input
                            # queue to avoid leaking requests. This is ok because
                            # aborting in the scheduler is idempotent.
                            self.aborts_queue.put_nowait(request)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread.

        Serializes items from ``output_queue`` and sends them to the client
        PUSH socket selected by the client index. Index ``-1`` routes to the
        DP coordinator socket. Exits after forwarding ENGINE_CORE_DEAD to all
        client sockets.
        """

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                )
                for output_path in output_paths
            ]
            coord_socket = (
                stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
                if coord_output_path is not None
                else None
            )
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with. New sends are
                # added on the left, so the oldest pending send is at [-1];
                # stop at the first one still in flight.
                while pending and pending[-1][0].done:
                    reuse_buffers.append(pending.pop()[2])

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(
                    buffers, copy=False, track=True
                )
                if not tracker.done:
                    # Only multi-frame messages reference memory owned by
                    # `outputs`, so only then must `outputs` be kept alive.
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)

    def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None:
        """Log and return a request-scoped error response for exceptions raised
        from the add request preprocessing in the input socket processing thread.
        """
        logger.exception(
            "Unexpected error pre-processing request %s", request.request_id
        )
        self.output_queue.put_nowait(
            (
                request.client_index,
                EngineCoreOutputs(
                    engine_index=self.engine_index,
                    finished_requests={request.request_id},
                    outputs=[
                        EngineCoreOutput(
                            request_id=request.request_id,
                            new_token_ids=[],
                            finish_reason=FinishReason.ERROR,
                        )
                    ],
                ),
            )
        )

    def pause_scheduler(
        self, mode: PauseMode = "abort", clear_cache: bool = True
    ) -> Future | None:
        """Pause generation; behavior depends on ``mode``.

        - ``"abort"``: abort every request known to the scheduler, send ABORT
          outputs to the owning clients, then set ``PauseState.PAUSED_NEW``.
        - ``"wait"``: set ``PauseState.PAUSED_NEW`` without aborting, so
          in-flight requests can presumably keep stepping and drain.
        - ``"keep"``: set ``PauseState.PAUSED_ALL``.

        In every mode, the engine counts as idle once the scheduler has no
        requests, the batch queue is empty and the output queue is empty. At
        that point the prefix/mm/encoder caches are reset (if ``clear_cache``)
        and the pause is complete. If the engine is already idle, this
        happens synchronously and ``None`` is returned. Otherwise a per-step
        hook is registered and a Future is returned that completes when the
        engine becomes idle.

        NOTE(review): which requests PAUSED_NEW / PAUSED_ALL allow to be
        scheduled is defined by the scheduler. Confirm there whether "keep"
        with in-flight requests can reach idle before resume.

        Raises:
            ValueError: if ``mode`` is not one of "keep", "abort", "wait".
        """
        if mode not in ("keep", "abort", "wait"):
            raise ValueError(f"Invalid pause mode: {mode}")

        future: Future[Any] = Future()

        def wait_until_idle(engine: "EngineCoreProc") -> bool:
            # Per-step hook: returns True (and is removed) once idle.
            scheduler = engine.scheduler
            out_queue = engine.output_queue
            if scheduler.has_requests() or engine.batch_queue or not out_queue.empty():
                return False
            if clear_cache:
                engine.reset_prefix_cache(reset_running_requests=True)
                engine.reset_mm_cache()
                engine.reset_encoder_cache()
            future.set_result(None)
            return True

        if mode == "abort":
            aborted_reqs = self.scheduler.finish_requests(
                None, RequestStatus.FINISHED_ABORTED
            )
            self._send_abort_outputs(aborted_reqs)

        pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
        self.scheduler.set_pause_state(pause_state)
        if not wait_until_idle(self):
            self.per_step_hooks.add(wait_until_idle)
            return future
        return None

    def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None:
        """Enqueue one ABORT ``EngineCoreOutputs`` per client.

        Args:
            aborted_reqs: ``(request_id, client_index)`` pairs.
        """
        if aborted_reqs:
            # Map client_index to list of request_ids that belong to that client.
            by_client = defaultdict[int, set[str]](set)
            for req_id, client_index in aborted_reqs:
                by_client[client_index].add(req_id)
            for client_index, req_ids in by_client.items():
                outputs = [
                    EngineCoreOutput(req_id, [], finish_reason=FinishReason.ABORT)
                    for req_id in req_ids
                ]
                eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs)
                self.output_queue.put_nowait((client_index, eco))

    def resume_scheduler(self) -> None:
        """Resume the scheduler and flush any requests queued while paused."""
        self.scheduler.set_pause_state(PauseState.UNPAUSED)

    def is_scheduler_paused(self) -> bool:
        """Return whether the scheduler is in any pause state."""
        return self.scheduler.pause_state != PauseState.UNPAUSED

1436
1437
1438
1439
1440
1441
1442
1443

class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        """Set up DP bookkeeping, then initialize the base engine core.

        The global DP rank (``parallel_config.data_parallel_rank``) is used as
        this engine's ``engine_index``.
        """
        assert vllm_config.model_config.is_moe, (
            "DPEngineCoreProc should only be used for MoE models"
        )

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        # Current DP wave number; advanced by incoming requests/START_DP_WAVE
        # and incremented when a wave completes in run_busy_loop().
        self.current_wave = 0
        # Last request counts published by _maybe_publish_request_counts().
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            engine_index=dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Record the DP rank and create the stateless DP process group.

        Presumably invoked from the base-class constructor (not visible here).
        """
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert local_dp_rank is not None
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine core, then destroy the DP process group.

        The group reference is cleared before it is destroyed, so calling
        shutdown() again (as reinitialize_distributed() does) never destroys
        the same group twice.
        """
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            self.dp_group = None  # type: ignore[assignment]
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request, reconciling its wave number with this engine's.

        With a coordinator, a request from a newer wave advances
        ``current_wave``. A request from an older wave that arrives while the
        engines are idle queues a ``start_wave`` message (client index ``-1``)
        so that the next wave gets started.
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle START_DP_WAVE here; delegate all other request types.

        A START_DP_WAVE request is a ``(new_wave, exclude_eng_index)`` pair.
        It is ignored by the excluded engine and for waves older than the
        current one. Otherwise it moves this engine to ``new_wave`` and starts
        the loop running if it was idle.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Queue scheduler request counts (client ``-1``) when they change.

        Does nothing unless ``publish_dp_lb_stats`` is set.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank still has unfinished requests.

        Only every 32nd call performs the DP all-reduce. The calls in between
        return True without communicating.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild the DP environment for a new data-parallel topology.

        Shuts the engine down (destroying the current DP group), applies the
        new DP size/rank/master address, and recreates the DP group unless
        this rank is being shut down. It then forwards the request to the
        executor. When DP grows, it also shares the KV-cache memory budget and
        warms up newly joined workers.
        """
        # shutdown() destroys the current DP group (and clears the reference),
        # so destroying it explicitly here as well would tear it down twice.
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        # Copy the master port back into the request before forwarding it.
        # NOTE(review): presumably stateless_init_dp_group() may advance the
        # port; confirm.
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )

class EngineCoreActorMixin:
    """
    Mixin that adapts an EngineCoreProc subclass to run as a Ray actor.

    It is combined with DPEngineCoreProc (MoE data-parallel case) or with
    EngineCoreProc (non-MoE and/or non-DP case). The other base class is
    expected to provide ``run_busy_loop`` and ``shutdown``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        addresses: EngineZmqAddresses,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        """Do Ray-specific setup before the engine core initializes.

        Initializes tracing, stores the pre-known ZMQ ``addresses``, records
        the DP index/local rank in ``vllm_config.parallel_config``, and
        restricts visible devices for this actor.
        """
        # Initialize tracer for distributed tracing if configured.
        maybe_init_worker_tracer(
            instrumenting_module_name="vllm.engine_core",
            process_kind="engine_core",
            process_name=f"DPEngineCoreActor_DP{dp_rank}",
        )

        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_index = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict this actor's visible devices to its local DP slice.

        XPU is skipped entirely. On other platforms, the platform's device
        control environment variable is set.
        """
        from vllm.platforms import current_platform

        if current_platform.is_xpu():
            return
        self._set_cuda_visible_devices(
            vllm_config, local_dp_rank, current_platform.device_control_env_var
        )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Set ``device_control_env_var`` to this local DP rank's device slice.

        The slice spans ``world_size`` devices starting at
        ``local_dp_rank * world_size``. The value is computed by
        ``get_device_indices``.

        Raises:
            Exception: if ``get_device_indices`` raises IndexError; the message
                includes the requested range and the current env value.
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
            os.environ[device_control_env_var] = value
        except IndexError as e:
            raise Exception(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        SystemExit is logged at debug level and re-raised. Other exceptions
        are logged with traceback and re-raised. shutdown() always runs on
        exit.
        """
        try:
            self.run_busy_loop()  # type: ignore[attr-defined]
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()  # type: ignore[attr-defined]


class DPMoEEngineCoreActor(EngineCoreActorMixin, DPEngineCoreProc):
    """Ray actor running a data-parallel engine core for an MoE model."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Record the global DP rank before either base initializer runs;
        # DPEngineCoreProc.__init__ reads it back as the engine index.
        vllm_config.parallel_config.data_parallel_rank = dp_rank

        # Ray-side setup first (tracing, addresses, device visibility) ...
        EngineCoreActorMixin.__init__(
            self,
            vllm_config,
            addresses,
            dp_rank=dp_rank,
            local_dp_rank=local_dp_rank,
        )
        # ... then the DP engine core. Ray actors skip the real handshake
        # (addresses are known up front), so the handshake address is empty.
        DPEngineCoreProc.__init__(
            self,
            vllm_config,
            local_client,
            handshake_address="",
            executor_class=executor_class,
            log_stats=log_stats,
        )


class EngineCoreActor(EngineCoreActorMixin, EngineCoreProc):
    """Ray actor for the non-MoE and/or non-data-parallel case."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Present a single-rank, non-DP view of the parallel config to the
        # engine core; dp_rank below only serves as the actor/engine index.
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_size = 1
        parallel_config.data_parallel_size_local = 1
        parallel_config.data_parallel_rank = 0

        EngineCoreActorMixin.__init__(
            self,
            vllm_config,
            addresses,
            dp_rank=dp_rank,
            local_dp_rank=local_dp_rank,
        )
        # Empty handshake address: Ray actors already know their addresses.
        EngineCoreProc.__init__(
            self,
            vllm_config,
            local_client,
            "",
            executor_class,
            log_stats,
            engine_index=dp_rank,
        )