# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Engine core: the inner schedule/execute loop (``EngineCore``) and the
ZMQ wrapper that runs it in a background process (``EngineCoreProc``)."""

import os
import queue
import signal
import threading
import time
from collections import defaultdict, deque
from collections.abc import Callable, Generator
from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, TypeVar, cast

import msgspec
import zmq

from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.envs import enable_envs_cache
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tracing import instrument, maybe_init_worker_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils.gc_utils import (
    freeze_gc_heap,
    maybe_attach_gc_debug_callback,
)
from vllm.utils.hashing import get_hash_fn_by_name
from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.system_utils import decorate_logs, set_process_title
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import (
    EngineCoreOutput,
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
    FinishReason,
    PauseMode,
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
from vllm.v1.executor import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import compute_iteration_details
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# NOTE(review): used by code outside this chunk; the `_S` / `_MINS` suffixes
# suggest seconds / minutes respectively — confirm at the use sites.
POLLING_TIMEOUT_S = 2.5
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar("_R")  # Return type for collective_rpc


class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler and the structured-output manager,
    and drives the schedule -> execute -> update-from-output cycle via
    `step` (or `step_with_batch_queue` when the executor supports more than
    one concurrent batch).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
        include_finished_set: bool = False,
    ):
        """Build the executor, size and initialize KV caches, and create the
        scheduler.

        Args:
            vllm_config: Full engine configuration. Its cache and scheduler
                sub-configs are mutated here (block counts, chunked prefill).
            executor_class: Executor type, instantiated with `vllm_config`.
            log_stats: Forwarded to the scheduler.
            executor_fail_callback: If given, registered on the executor as
                its failure callback.
            include_finished_set: Forwarded to the scheduler.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        if not vllm_config.parallel_config.data_parallel_rank_local:
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        # -1 until _initialize_kv_caches determines the real value.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        Scheduler = vllm_config.scheduler_config.get_scheduler_cls()

        if len(kv_cache_config.kv_cache_groups) == 0:  # noqa: SIM102
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            if vllm_config.scheduler_config.enable_chunked_prefill:
                logger.warning("Disabling chunked prefill for model without KVCache")
                vllm_config.scheduler_config.enable_chunked_prefill = False

        # The scheduler's logical block spans the cache block size multiplied
        # by both context-parallel degrees.
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
            * vllm_config.parallel_config.prefill_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=include_finished_set,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None
        if self.scheduler.connector is not None:  # type: ignore
            self.model_executor.init_kv_output_aggregator(self.scheduler.connector)  # type: ignore

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = mm_registry.engine_receiver_cache_from_config(
            vllm_config
        )

        # If a KV connector is initialized for scheduler, we want to collect
        # handshake metadata from all workers so the connector in the scheduler
        # will have the full context
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            # Collect and store KV connector xfer metadata from workers
            # (after KV cache registration)
            xfer_handshake_metadata = (
                self.model_executor.get_kv_connector_handshake_metadata()
            )

            if xfer_handshake_metadata:
                # xfer_handshake_metadata is list of dicts from workers
                # Each dict already has structure {tp_rank: metadata}
                # Merge all worker dicts into a single dict
                content: dict[int, Any] = {}
                for worker_dict in xfer_handshake_metadata:
                    if worker_dict is not None:
                        content.update(worker_dict)
                kv_connector.set_xfer_handshake_metadata(content)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput, Future[Any]]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.debug("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        self.is_ec_producer = (
            vllm_config.ec_transfer_config is not None
            and vllm_config.ec_transfer_config.is_ec_producer
        )
        self.is_pooling_model = vllm_config.model_config.runner_type == "pooling"

        # Block hashes are needed for prefix caching and by KV connectors.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )
        self.async_scheduling = vllm_config.scheduler_config.async_scheduling

        # Request ids to abort, drained by _process_aborts_queue between steps.
        self.aborts_queue = queue.Queue[list[str]]()

        self.per_step_hooks: set[Callable] = set()

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()
        # If enable, attach GC debugger after static variable freeze.
        maybe_attach_gc_debug_callback()
        # Enable environment variable cache (e.g. assume no more
        # environment variable overrides after this point)
        enable_envs_cache()

    @instrument(span_name="Prepare model")
    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Size the KV cache, build per-worker configs and initialize them.

        Available memory is either profiled by the executor or, when
        ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH == "1"``, synced across the DP group.
        Model without any KV cache spec get 0 bytes.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``;
            ``num_cpu_blocks`` is always 0 here.
        """
        # Monotonic clock: wall-clock adjustments must not skew the duration.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        # Track max_model_len before KV cache config to detect auto-fit changes
        max_model_len_before = vllm_config.model_config.max_model_len

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )

        # If auto-fit reduced max_model_len, sync the new value to workers.
        # This is needed because workers were spawned before memory profiling
        # and have the original (larger) max_model_len cached.
        max_model_len_after = vllm_config.model_config.max_model_len
        if max_model_len_after != max_model_len_before:
            self.collective_rpc("update_max_model_len", args=(max_model_len_after,))

        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info_once(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
            scope="local",
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the model executor."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case (unused by this base implementation).

        Raises:
            TypeError: if ``request.request_id`` is not a ``str``.
            ValueError: if the request has pooling params whose task is not
                one of the model's supported pooling tasks.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        if request.kv_transfer_params is not None and (
            not self.scheduler.get_kv_connector()
        ):
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Dump engine state if the wrapped block raises, then re-raise.

        On any ``Exception`` the config, ``scheduler_output`` and current
        scheduler stats are passed to ``dump_engine_exception`` and the
        original exception propagates unchanged.
        """
        try:
            yield
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            raise

    @contextmanager
    def log_iteration_details(self, scheduler_output: SchedulerOutput):
        """Log per-iteration batch statistics around the wrapped block.

        Active only when ``enable_logging_iteration_details`` is set in the
        observability config. The line is emitted after the block finishes;
        if the block raises, the exception propagates and nothing is logged.
        """
        if not self.vllm_config.observability_config.enable_logging_iteration_details:
            yield
            return
        self._iteration_index = getattr(self, "_iteration_index", 0)
        iteration_details = compute_iteration_details(scheduler_output)
        before = time.monotonic()
        yield
        elapsed_ms = (time.monotonic() - before) * 1000
        # Lazy %-style args: formatting is skipped when INFO is disabled.
        logger.info(
            "Iteration(%s): %s context requests, %s context tokens, "
            "%s generation requests, %s generation tokens, "
            "iteration elapsed time: %.2f ms",
            self._iteration_index,
            iteration_details.num_ctx_requests,
            iteration_details.num_ctx_tokens,
            iteration_details.num_generation_requests,
            iteration_details.num_generation_tokens,
            elapsed_ms,
        )
        self._iteration_index += 1

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output, non_block=True)
        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
            model_output = future.result()
            # A None result means sampling is a separate call.
            if model_output is None:
                model_output = self.model_executor.sample_tokens(grammar_output)

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

    def post_step(self, model_executed: bool) -> None:
        """Push fresh draft token ids to the scheduler after a step.

        Only done for synchronous scheduling with speculative decoding, and
        only when the model actually ran.
        """
        # When using async scheduling we can't get draft token ids in advance,
        # so we update draft token ids in the worker process and don't
        # need to update draft token ids here.
        if not self.async_scheduling and self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """

        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        deferred_scheduler_output = None
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            exec_future = self.model_executor.execute_model(
                scheduler_output, non_block=True
            )
            if not self.is_ec_producer:
                model_executed = scheduler_output.total_num_scheduled_tokens > 0

            if self.is_pooling_model or not model_executed:
                # No sampling required (no requests scheduled).
                future = cast(Future[ModelRunnerOutput], exec_future)
            else:
                if not scheduler_output.pending_structured_output_tokens:
                    # We aren't waiting for any tokens, get any grammar output
                    # and sample immediately.
                    grammar_output = self.scheduler.get_grammar_bitmask(
                        scheduler_output
                    )
                    future = self.model_executor.sample_tokens(
                        grammar_output, non_block=True
                    )
                else:
                    # We need to defer sampling until we have processed the model output
                    # from the prior step.
                    deferred_scheduler_output = scheduler_output

            # Explicit None check: `future` is only bound on the non-deferred
            # path, so this must not depend on the output's truthiness.
            if deferred_scheduler_output is None:
                # Add this step's future to the queue.
                batch_queue.appendleft((future, scheduler_output, exec_future))
                if (
                    model_executed
                    and len(batch_queue) < self.batch_queue_size
                    and not batch_queue[-1][0].done()
                ):
                    # Don't block on next worker response unless the queue is full
                    # or there are no more requests to schedule.
                    return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output, exec_model_fut = batch_queue.pop()
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
            model_output = future.result()
            if model_output is None:
                # None from sample_tokens() implies that the original execute_model()
                # call failed - raise that exception.
                exec_model_fut.result()
                raise RuntimeError("unexpected error")

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        # NOTE(nick): We can either handle the deferred tasks here or save
        # in a field and do it immediately once step_with_batch_queue is
        # re-called. The latter slightly favors TTFT over TPOT/throughput.
        if deferred_scheduler_output is not None:
            # If we are doing speculative decoding with structured output,
            # we need to get the draft token ids from the prior step before
            # we can compute the grammar bitmask for the deferred request.
            if self.use_spec_decode:
                draft_token_ids = self.model_executor.take_draft_token_ids()
                assert draft_token_ids is not None
                # Update the draft token ids in the scheduler output to
                # filter out the invalid spec tokens, which will be padded
                # with -1 and skipped by the grammar bitmask computation.
                self.scheduler.update_draft_token_ids_in_output(
                    draft_token_ids, deferred_scheduler_output
                )
            # We now have the tokens needed to compute the bitmask for the
            # deferred request. Get the bitmask and call sample tokens.
            grammar_output = self.scheduler.get_grammar_bitmask(
                deferred_scheduler_output
            )
            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
            batch_queue.appendleft((future, deferred_scheduler_output, exec_future))

        return engine_core_outputs, model_executed

    def _process_aborts_queue(self):
        """Drain `aborts_queue` and abort every collected request id at once."""
        if not self.aborts_queue.empty():
            request_ids = []
            while not self.aborts_queue.empty():
                ids = self.aborts_queue.get_nowait()
                # Should be a list here, but also handle string just in case.
                request_ids.extend((ids,) if isinstance(ids, str) else ids)
            # More efficient to abort all as a single batch.
            self.abort_requests(request_ids)

    def shutdown(self):
        """Clear the structured-output backend, then shut down the executor
        and the scheduler (each only if set)."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True, profile_prefix: str | None = None):
        """Start (``is_start=True``) or stop profiling on the executor."""
        self.model_executor.profile(is_start, profile_prefix)

    def reset_mm_cache(self):
        """Clear the engine-side multi-modal receiver cache (if any) and the
        executor's multi-modal cache."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Delegate to the scheduler's prefix-cache reset and return its result."""
        return self.scheduler.reset_prefix_cache(
            reset_running_requests, reset_connector
        )

    def reset_encoder_cache(self) -> None:
        """Reset the encoder cache to invalidate all cached encoder outputs.

        This should be called when model weights are updated to ensure
        stale vision embeddings computed with old weights are not reused.
        Clears both the scheduler's cache manager and the GPU model runner's cache.
        """
        # NOTE: We don't attempt to re-sync the internal caches (P0 sender,
        # P1 receiver) with in-flight requests, so warn if any are running.
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the encoder cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # Reset the scheduler's encoder cache manager (logical state)
        self.scheduler.reset_encoder_cache()
        # Reset the GPU model runner's encoder cache (physical storage)
        self.model_executor.reset_encoder_cache()

    def pause_scheduler(
        self, mode: PauseMode = "abort", clear_cache: bool = True
    ) -> Future[Any] | None:
        """Pause scheduling. No-op in base EngineCore; overridden in EngineCoreProc."""
        return None

    def resume_scheduler(self) -> None:
        """Resume scheduling. No-op in base EngineCore; overridden in EngineCoreProc."""

    def is_scheduler_paused(self) -> bool:
        """Return whether the scheduler is in any pause state. False in base EngineCore
        and overridden in EngineCoreProc."""
        return False

    def sleep(self, level: int = 1):
        """Put the engine to sleep at the specified level.

        Args:
            level: Sleep level.
                - Level 0: Pause scheduling only. Requests are still accepted
                           but not processed. No GPU memory changes.
                - Level 1: Offload model weights to CPU, discard KV cache.
                - Level 2: Discard all GPU memory.
        """
        if level == 0:
            # Level 0: Just pause scheduling, don't touch GPU
            self.pause_scheduler()
        else:
            # Level 1+: Delegate to executor for GPU memory management
            self.model_executor.sleep(level)

    def wake_up(self, tags: list[str] | None = None):
        """Wake up the engine from sleep.

        Args:
            tags: Tags to wake up. Use ["scheduling"] for level 0 wake up.
        """
        if tags is not None and "scheduling" in tags:
            # Level 0 wake up: Resume scheduling
            self.resume_scheduler()
            # Remove "scheduling" from tags if there are other tags to process
            remaining_tags = [t for t in tags if t != "scheduling"]
            if remaining_tags:
                self.model_executor.wake_up(remaining_tags)
        else:
            # Full wake up
            self.resume_scheduler()
            self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Check if engine is sleeping at any level."""
        return self.is_scheduler_paused() or self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Delegate a dummy-batch execution to the executor."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter via the executor; returns the executor's result."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter via the executor; returns its result."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter via the executor; returns its result."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Delegate saving sharded model state to the executor."""
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Run `method` on the executor's workers and return their results."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The scheduler-side `Request` and the request's `current_wave`.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process.

    Client traffic is moved between ZMQ sockets and in-process queues by two
    daemon threads. ``process_input_sockets`` fills ``input_queue`` and
    ``process_output_sockets`` drains ``output_queue``. The main thread runs
    ``run_busy_loop``.
    """

    # Sentinel pushed onto ``output_queue`` (and forwarded to every client
    # output socket) to tell clients that this engine core has died.
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"

    @instrument(span_name="EngineCoreProc init")
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        *,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, initialize the engine and start IO.

        Args:
            vllm_config: Engine configuration. Its parallel config may be
                updated in place by the startup handshake.
            local_client: Whether the front-end client is colocated with
                this process.
            handshake_address: ZMQ address of the (primary) startup handshake.
            executor_class: Executor type passed to ``EngineCore``.
            log_stats: Passed to ``EngineCore``.
            client_handshake_address: Optional address for a second handshake
                with the colocated front-end (requires ``local_client``).
            engine_index: Index of this engine. It is used as the ZMQ
                identity (2 bytes, little-endian) and stamped on outputs.

        Raises:
            RuntimeError: If the input socket thread dies while waiting for
                the DP Coordinator's READY message.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()
        # Surface executor failures through the busy loop: EXECUTOR_FAILED is
        # turned into a RuntimeError by _handle_client_request.
        executor_fail_callback = lambda: self.input_queue.put_nowait(
            (EngineCoreRequestType.EXECUTOR_FAILED, b"")
        )

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        # The READY message(s) of the handshake are sent when this context
        # exits, i.e. only after the engine and IO threads are initialized.
        with self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        ) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            internal_dp_balancing = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )
            # Only publish request queue stats to coordinator for "internal"
            # and "hybrid" LB modes.
            self.publish_dp_lb_stats = internal_dp_balancing

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config,
                executor_class,
                log_stats,
                executor_fail_callback,
                internal_dp_balancing,
            )

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    identity,
                    ready_event,
                ),
                daemon=True,
            )
            input_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(
                    addresses.outputs,
                    addresses.coordinator_output,
                    self.engine_index,
                ),
                daemon=True,
            )
            self.output_thread.start()

            # Don't complete handshake until DP coordinator ready message is
            # received.
            while not ready_event.wait(timeout=10):
                if not input_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.

        Only the first handshake may update ``vllm_config.parallel_config``.
        When a second handshake is performed, its input/output addresses
        replace those of the first. ``vllm_config.__post_init__()`` is
        re-run after the handshakes complete normally.
        """
        input_ctx = zmq.Context()
        is_local = local_client and client_handshake_address is None
        headless = not local_client
        handshake = self._perform_handshake(
            input_ctx,
            handshake_address,
            identity,
            is_local,
            headless,
            vllm_config,
            vllm_config.parallel_config,
        )
        if client_handshake_address is None:
            with handshake as addresses:
                yield addresses
        else:
            assert local_client
            local_handshake = self._perform_handshake(
                input_ctx, client_handshake_address, identity, True, False, vllm_config
            )
            with handshake as addresses, local_handshake as client_addresses:
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
                yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run a single HELLO/READY handshake over a DEALER socket.

        The HELLO registration is sent and the init message received before
        yielding the addresses. The READY message is sent after the caller's
        ``with`` body completes. It carries ``num_gpu_blocks``, the stats
        publish address, and (for DP>1) the parallel config hash.
        """
        with make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        ) as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(
                handshake_socket, local_client, headless, parallel_config_to_update
            )
            yield addresses

            # Send ready message.
            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
            # We pass back the coordinator stats update address here for the
            # external LB case for our colocated front-end to use (coordinator
            # only runs with rank 0).
            dp_stats_address = self.frontend_stats_publish_address

            # Include config hash for DP configuration validation
            ready_msg = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": num_gpu_blocks,
                "dp_stats_address": dp_stats_address,
            }
            if vllm_config.parallel_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = (
                    vllm_config.parallel_config.compute_hash()
                )

            handshake_socket.send(msgspec.msgpack.encode(ready_msg))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Send HELLO and wait for the front-end's init message.

        Args:
            handshake_socket: Connected handshake socket.
            local_client: Reported to the front-end in the HELLO message.
            headless: Reported to the front-end in the HELLO message.
            parallel_config: If given, every key/value in the init message's
                ``parallel_config`` is written onto it with ``setattr``.

        Returns:
            The ZMQ addresses contained in the init message.

        Raises:
            RuntimeError: If no init message arrives within
                ``HANDSHAKE_TIMEOUT_MINS`` minutes.
        """
        # Send registration message.
        handshake_socket.send(
            msgspec.msgpack.encode(
                {
                    "status": "HELLO",
                    "local": local_client,
                    "headless": headless,
                }
            )
        )

        # Receive initialization message.
        logger.debug("Waiting for init message from front-end.")
        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                f"minutes"
            )
        init_bytes = handshake_socket.recv()
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            init_bytes, type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for key, value in init_message.parallel_config.items():
                setattr(parallel_config, key, value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers, configures data-parallel settings
        and then builds either a ``DPEngineCoreProc`` (DP with an MoE model)
        or a plain ``EngineCoreProc``. Finally it runs the busy loop until
        exit. ``kwargs`` must contain ``vllm_config``.

        Args:
            dp_rank: Global data-parallel rank of this engine.
            local_dp_rank: Data-parallel rank local to this node.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: EngineCoreProc | None = None
        try:
            vllm_config: VllmConfig = kwargs["vllm_config"]
            parallel_config: ParallelConfig = vllm_config.parallel_config
            data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
            if data_parallel:
                parallel_config.data_parallel_rank_local = local_dp_rank
                maybe_init_worker_tracer(
                    instrumenting_module_name="vllm.engine_core",
                    process_kind="engine_core",
                    process_name=f"EngineCore_DP{dp_rank}",
                )
                set_process_title("EngineCore", f"DP{dp_rank}")
            else:
                maybe_init_worker_tracer(
                    instrumenting_module_name="vllm.engine_core",
                    process_kind="engine_core",
                    process_name="EngineCore",
                )
                set_process_title("EngineCore")
            decorate_logs()

            if data_parallel and vllm_config.kv_transfer_config is not None:
                # modify the engine_id and append the local_dp_rank to it to ensure
                # that the kv_transfer_config is unique for each DP rank.
                vllm_config.kv_transfer_config.engine_id = (
                    f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
                )
                logger.debug(
                    "Setting kv_transfer_config.engine_id to %s",
                    vllm_config.kv_transfer_config.engine_id,
                )

            parallel_config.data_parallel_index = dp_rank
            if data_parallel and vllm_config.model_config.is_moe:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                # Non-MoE DP ranks are completely independent, so treat like DP=1.
                # Note that parallel_config.data_parallel_index will still reflect
                # the original DP rank.
                parallel_config.data_parallel_size = 1
                parallel_config.data_parallel_size_local = 1
                parallel_config.data_parallel_rank = 0
                engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)

            assert engine_core is not None
            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare re-raise keeps the original traceback unchanged.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Data-parallel setup hook, called before ``EngineCore.__init__``.

        It is a no-op here; ``DPEngineCoreProc`` overrides it.
        """
        pass

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()
            # 3) Run any per-step hooks.
            self._process_per_step_hooks()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed.

        Blocks on ``input_queue`` while there is no work: no DP engines
        running, no scheduler requests, no in-flight batches and no per-step
        hooks. Once it stops blocking, it drains any other queued client
        requests without blocking.
        """

        waited = False
        while (
            not self.engines_running
            and not self.scheduler.has_requests()
            and not self.batch_queue
            and not self.per_step_hooks
        ):
            if self.input_queue.empty():
                # Drain aborts queue; all aborts are also processed via input_queue.
                with self.aborts_queue.mutex:
                    self.aborts_queue.queue.clear()
                if logger.isEnabledFor(DEBUG):
                    logger.debug("EngineCore waiting for work.")
                    # Only flag when logged, so "loop active" is paired with
                    # a preceding "waiting for work" message.
                    waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self) -> bool:
        """Run one engine step and enqueue its outputs.

        Called once per busy-loop iteration after ``_process_input_queue``
        returns. That happens when there are requests, in-flight batches,
        pending per-step hooks, or DP engines running.

        Returns:
            Whether the model was executed during this step.
        """

        # Step the engine core.
        outputs, model_executed = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        for output in outputs.items() if outputs else ():
            self.output_queue.put_nowait(output)
        # Post-step hook.
        self.post_step(model_executed)

        # If no model execution happened but there are waiting requests
        # (e.g., WAITING_FOR_REMOTE_KVS), yield the GIL briefly to allow
        # background threads (like NIXL handshake) to make progress.
        # Without this, the tight polling loop can starve background threads.
        if not model_executed and self.scheduler.has_unfinished_requests():
            time.sleep(0.001)

        return model_executed

    def _process_per_step_hooks(self) -> None:
        """Invoke every registered per-step hook with this engine.

        A hook that returns a truthy value is removed. The hooks are iterated
        over a snapshot, so removing one during the loop is safe.
        """
        if self.per_step_hooks:
            for hook in list(self.per_step_hooks):
                finished = hook(self)
                if finished:
                    self.per_step_hooks.discard(hook)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch request from client.

        Raises:
            RuntimeError: On an ``EXECUTOR_FAILED`` notification.
        """

        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            # Lazily look-up utility method so that failure will be handled/returned.
            get_result = lambda: (method := getattr(self, method_name)) and method(
                *self._convert_msgspec_args(method, args)
            )
            enqueue_output = lambda out: self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=out))
            )
            self._invoke_utility_method(method_name, get_result, output, enqueue_output)
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )

    @staticmethod
    def _invoke_utility_method(
        name: str, get_result: Callable, output: UtilityOutput, enqueue_output: Callable
    ):
        """Evaluate a utility call and enqueue its result or failure message.

        If the call returns a ``Future``, enqueuing is deferred until the
        future completes; the same logic is re-run on ``future.result``.
        Exceptions are logged and reported to the client, not propagated.
        """
        try:
            result = get_result()
            if isinstance(result, Future):
                # Defer utility output handling until future completion.
                callback = lambda future: EngineCoreProc._invoke_utility_method(
                    name, future.result, output, enqueue_output
                )
                result.add_done_callback(callback)
                return
            output.result = UtilityResult(result)
        except Exception as e:
            logger.exception("Invocation of %s method failed", name)
            output.failure_message = f"Call to {name} method failed: {str(e)}"
        enqueue_output(output)

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
        arg type, try converting to msgspec object.

        Only parameters whose annotation is a ``msgspec.Struct`` subclass are
        converted; all other args are passed through untouched.
        """
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, arg_types)
        )

    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient."""

        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until msg sent by the daemon before shutdown.
        self.output_thread.join(timeout=5.0)
        if self.output_thread.is_alive():
            logger.fatal(
                "vLLM shutdown signal from EngineCore failed "
                "to send. Please report this issue."
            )

    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Decodes client (and optional DP Coordinator) messages and pushes
        ``(request_type, request)`` tuples onto ``input_queue``. ADD requests
        are preprocessed here. ``ready_event`` is set once all sockets are
        connected and, if a coordinator is used, its READY message has been
        received.

        Raises:
            RuntimeError: If the coordinator's first message is not ``READY``.
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. The recv() must not
                # live inside an ``assert``: under ``python -O`` the whole
                # statement would be stripped, the READY frame would never be
                # consumed, and it would later be mis-parsed as a request.
                coord_ready_msg = coord_socket.recv()
                if coord_ready_msg != b"READY":
                    raise RuntimeError(
                        "Expected READY message from DP Coordinator, "
                        f"got {coord_ready_msg!r}"
                    )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    request: Any
                    if request_type == EngineCoreRequestType.ADD:
                        req: EngineCoreRequest = add_request_decoder.decode(data_frames)
                        try:
                            request = self.preprocess_add_request(req)
                        except Exception:
                            self._handle_request_preproc_error(req)
                            continue
                    else:
                        request = generic_decoder.decode(data_frames)

                        if request_type == EngineCoreRequestType.ABORT:
                            # Aborts are added to *both* queues, allows us to eagerly
                            # process aborts while also ensuring ordering in the input
                            # queue to avoid leaking requests. This is ok because
                            # aborting in the scheduler is idempotent.
                            self.aborts_queue.put_nowait(request)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread.

        Drains ``output_queue`` and sends each item to its client socket.
        ``client_index == -1`` routes the item to the coordinator. Exits after
        forwarding ``ENGINE_CORE_DEAD`` to all client sockets.
        """

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        # Newest entries are on the left, oldest on the right.
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                )
                for output_path in output_paths
            ]
            coord_socket = (
                stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
                if coord_output_path is not None
                else None
            )
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with.
                while pending and pending[-1][0].done:
                    reuse_buffers.append(pending.pop()[2])

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(
                    buffers, copy=False, track=True
                )
                if not tracker.done:
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)

    def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None:
        """Log and return a request-scoped error response for exceptions raised
        from the add request preprocessing in the input socket processing thread.

        Must be called from within an ``except`` block, because it uses
        ``logger.exception``. The request is reported as finished with
        ``FinishReason.ERROR``.
        """
        logger.exception(
            "Unexpected error pre-processing request %s", request.request_id
        )
        self.output_queue.put_nowait(
            (
                request.client_index,
                EngineCoreOutputs(
                    engine_index=self.engine_index,
                    finished_requests={request.request_id},
                    outputs=[
                        EngineCoreOutput(
                            request_id=request.request_id,
                            new_token_ids=[],
                            finish_reason=FinishReason.ERROR,
                        )
                    ],
                ),
            )
        )

    def pause_scheduler(
        self, mode: PauseMode = "abort", clear_cache: bool = True
    ) -> Future | None:
        """Pause generation; behavior depends on mode.

        All pause modes queue new adds -- "abort" and "keep" skip step();
        "wait" allows step() so in-flight requests can drain.

        - ``abort``: Set PAUSED_NEW, abort all requests, wait for abort
          outputs to be sent (when running with output_queue), optionally
          clear caches, then complete the returned Future.
        - ``wait``: Set PAUSED_NEW (queue adds, keep stepping); when drained,
          optionally clear caches, then complete the returned Future.
        - ``keep``: Set PAUSED_ALL; return a Future that completes when the
          output queue is empty.

        In every mode, completion is detected by a per-step hook. The hook
        requires no scheduler requests, no in-flight batches, and an empty
        output queue.

        Returns:
            A Future that completes once idle, or ``None`` if the engine is
            already idle on entry. In that case any cache clearing has
            already been done.

        Raises:
            ValueError: If ``mode`` is not one of "keep", "abort", "wait".
        """
        if mode not in ("keep", "abort", "wait"):
            raise ValueError(f"Invalid pause mode: {mode}")

        future: Future[Any] = Future()

        def wait_until_idle(engine: "EngineCoreProc") -> bool:
            scheduler = engine.scheduler
            out_queue = engine.output_queue
            if scheduler.has_requests() or engine.batch_queue or not out_queue.empty():
                return False
            if clear_cache:
                engine.reset_prefix_cache(reset_running_requests=True)
                engine.reset_mm_cache()
                engine.reset_encoder_cache()
            future.set_result(None)
            return True

        if mode == "abort":
            aborted_reqs = self.scheduler.finish_requests(
                None, RequestStatus.FINISHED_ABORTED
            )
            self._send_abort_outputs(aborted_reqs)

        pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
        self.scheduler.set_pause_state(pause_state)
        if not wait_until_idle(self):
            self.per_step_hooks.add(wait_until_idle)
            return future
        return None

    def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None:
        """Enqueue one ABORT output batch per client for the aborted requests.

        Args:
            aborted_reqs: ``(request_id, client_index)`` pairs.
        """
        if aborted_reqs:
            # Map client_index to list of request_ids that belong to that client.
            by_client = defaultdict[int, set[str]](set)
            for req_id, client_index in aborted_reqs:
                by_client[client_index].add(req_id)
            for client_index, req_ids in by_client.items():
                outputs = [
                    EngineCoreOutput(req_id, [], finish_reason=FinishReason.ABORT)
                    for req_id in req_ids
                ]
                eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs)
                self.output_queue.put_nowait((client_index, eco))

    def resume_scheduler(self) -> None:
        """Resume the scheduler and flush any requests queued while paused."""
        self.scheduler.set_pause_state(PauseState.UNPAUSED)

    def is_scheduler_paused(self) -> bool:
        """Return whether the scheduler is in any pause state."""
        return self.scheduler.pause_state != PauseState.UNPAUSED

1429
1430
1431
1432
1433
1434
1435
1436

class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context.

    Only used for MoE models (asserted in ``__init__``). On top of
    ``EngineCoreProc`` it tracks the current DP "wave", publishes request
    counts when they change, and keeps stepping (running dummy batches when
    there is no local work) until all DP ranks agree the wave is finished.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        assert vllm_config.model_config.is_moe, (
            "DPEngineCoreProc should only be used for MoE models"
        )

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        self.current_wave = 0
        # Request counts most recently published by
        # _maybe_publish_request_counts().
        self.last_counts = (0, 0)

        # Initialize the engine; the DP rank is used as the engine index.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            engine_index=dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Validate the DP ranks and create the stateless DP process group.

        Sets ``self.dp_rank`` and ``self.dp_group``. Asserts that DP is
        enabled (size > 1), that a local DP rank is assigned, and that
        ``0 <= local_dp_rank <= dp_rank < dp_size``.
        """
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert local_dp_rank is not None
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine core, then destroy the DP group if present."""
        super().shutdown()
        # getattr: the group may be absent (e.g. init did not get that far,
        # or reinitialize_distributed already destroyed and dropped it).
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request, reconciling its DP wave with this engine's wave.

        With a coordinator, a request from a newer wave advances
        ``current_wave``; a request from an older wave that arrives while
        the engines are idle makes this engine ask (via client index -1)
        for the next wave to be started.
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle START_DP_WAVE here; delegate other request types.

        A START_DP_WAVE payload is ``(new_wave, exclude_eng_index)``. Unless
        this engine is the excluded one or the wave is older than the current
        one, the wave is adopted and, if idle, the engine loop is started.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if (
                exclude_eng_index != self.engine_index
                and new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Publish scheduler request counts when they change.

        No-op unless ``publish_dp_lb_stats`` is set. Stats go to client
        index -1 together with the current step counter and wave.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case.

        While a wave is running, every rank keeps stepping (with a dummy
        batch when it has nothing to execute) so collective ops stay in
        lock-step; the wave ends once no rank reports unfinished requests.
        """

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank still has unfinished requests.

        Increments ``step_counter`` on every call. Only every 32nd call
        performs the cross-rank check; all other calls return ``True``
        without communicating.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild this rank's distributed state for a new DP configuration.

        Destroys the current DP group and shuts the engine down. It then
        applies the new DP size, rank and master address from
        ``reconfig_request`` to the parallel config. Unless this rank is
        being shut down, it re-creates the DP group. Finally it forwards the
        request to the model executor.

        When DP grows, the existing KV-cache memory budget is shared with the
        new ranks and the model is warmed up. A rank marked
        ``SHUTDOWN_CURRENT_RANK`` is shut down at the end.

        ``reconfig_request.new_data_parallel_master_port`` is overwritten
        with the port stored in the parallel config. NOTE(review): presumably
        this propagates a port chosen during DP group init; confirm.
        """
        stateless_destroy_torch_distributed_process_group(self.dp_group)
        # The group is already destroyed: drop the reference so shutdown()
        # (here and on the SHUTDOWN_CURRENT_RANK path below) does not
        # destroy the same group again.
        del self.dp_group
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )


class EngineCoreActorMixin:
    """
    Ray actor behavior for running an EngineCore.

    Mixed into ``EngineCoreProc`` subclasses, both the data-parallel MoE
    actor (``DPMoEEngineCoreActor``) and the non-MoE / non-DP actor
    (``EngineCoreActor``). It does four things:

    - replaces the ZMQ handshake with addresses known at actor creation;
    - restricts device visibility for this actor's local DP rank;
    - exposes ``wait_for_init``;
    - exposes ``run``.

    Subclasses call its ``__init__`` before the engine-core ``__init__`` so
    device visibility is configured as early as possible.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        addresses: EngineZmqAddresses,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Initialize tracer for distributed tracing if configured.
        maybe_init_worker_tracer(
            instrumenting_module_name="vllm.engine_core",
            process_kind="engine_core",
            process_name=f"DPEngineCoreActor_DP{dp_rank}",
        )

        # Yielded by _perform_handshakes() in place of a real handshake.
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_index = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict this actor's visible devices to its local DP rank's slice.

        Does nothing on XPU. On other platforms, sets the platform's
        device-control env var (e.g. ``CUDA_VISIBLE_DEVICES``).
        """
        from vllm.platforms import current_platform

        # NOTE(review): XPU is deliberately skipped; the reason is not
        # visible from this file.
        if not current_platform.is_xpu():
            device_control_env_var = current_platform.device_control_env_var
            self._set_cuda_visible_devices(
                vllm_config, local_dp_rank, device_control_env_var
            )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Export the device indices for ``local_dp_rank`` via the env var.

        Each local DP rank owns ``world_size`` consecutive device slots,
        starting at ``local_dp_rank * world_size``.

        Raises:
            RuntimeError: if ``get_device_indices`` raises ``IndexError``.
                The message reports the requested range and the env var's
                current value.
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e
        os.environ[device_control_env_var] = value

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.

        All arguments are ignored; the addresses passed to ``__init__`` are
        yielded unchanged.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        Exceptions are logged and re-raised. ``shutdown()`` always runs
        afterwards.
        """
        try:
            self.run_busy_loop()  # type: ignore[attr-defined]
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()  # type: ignore[attr-defined]


class DPMoEEngineCoreActor(EngineCoreActorMixin, DPEngineCoreProc):
    """Ray actor hosting one data-parallel rank of an MoE model's engine core."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Record the global DP rank before either base initializer runs.
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_rank = dp_rank

        # Actor setup (tracer, addresses, device visibility) goes first...
        EngineCoreActorMixin.__init__(
            self,
            vllm_config,
            addresses,
            dp_rank=dp_rank,
            local_dp_rank=local_dp_rank,
        )
        # ...then the DP engine core. The handshake address is left empty:
        # the mixin's _perform_handshakes ignores it and yields the
        # pre-known addresses instead.
        DPEngineCoreProc.__init__(
            self,
            vllm_config,
            local_client,
            handshake_address="",
            executor_class=executor_class,
            log_stats=log_stats,
        )


class EngineCoreActor(EngineCoreActorMixin, EngineCoreProc):
    """Ray actor hosting an engine core for non-MoE and/or non-DP cases.

    The engine core runs without data parallelism: the parallel config is
    forced to a single DP rank before initialization.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Collapse the DP topology seen by this engine core to one rank.
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_size = 1
        parallel_config.data_parallel_size_local = 1
        parallel_config.data_parallel_rank = 0

        EngineCoreActorMixin.__init__(
            self,
            vllm_config,
            addresses,
            dp_rank=dp_rank,
            local_dp_rank=local_dp_rank,
        )
        # dp_rank is still passed through as this engine's index.
        EngineCoreProc.__init__(
            self,
            vllm_config,
            local_client,
            "",
            executor_class,
            log_stats,
            engine_index=dp_rank,
        )