core.py 81 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
import queue
5
import signal
6
7
import threading
import time
8
from collections import defaultdict, deque
9
from collections.abc import Callable, Generator
10
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
11
from contextlib import ExitStack, contextmanager
12
from enum import IntEnum
13
from functools import partial
14
from inspect import isclass, signature
15
from logging import DEBUG
16
from multiprocessing.queues import Queue
17
from typing import Any, TypeVar, cast
18

19
import msgspec
20
21
import zmq

22
import vllm.envs as envs
23
24
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
25
from vllm.envs import enable_envs_cache
26
from vllm.logger import init_logger
27
from vllm.logging_utils.dump_input import dump_engine_exception
28
from vllm.lora.request import LoRARequest
29
from vllm.multimodal import MULTIMODAL_REGISTRY
30
from vllm.tasks import POOLING_TASKS, SupportedTask
31
from vllm.tracing import instrument, maybe_init_worker_tracer
32
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
33
from vllm.utils import numa_utils
34
35
36
37
from vllm.utils.gc_utils import (
    freeze_gc_heap,
    maybe_attach_gc_debug_callback,
)
38
from vllm.utils.hashing import get_hash_fn_by_name
39
from vllm.utils.network_utils import make_zmq_socket
40
from vllm.utils.system_utils import decorate_logs, set_process_title
41
42
43
44
45
46
47
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
48
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
49
from vllm.v1.core.sched.output import SchedulerOutput
50
from vllm.v1.engine import (
51
52
    EEP_NOTIFICATION_CALL_ID,
    EEPNotificationType,
53
    EngineCoreOutput,
54
55
56
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
57
    FinishReason,
58
    PauseMode,
59
60
61
62
63
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
64
from vllm.v1.engine.tensor_ipc import TensorIpcReceiver
65
66
67
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
68
    SignalCallback,
69
70
    get_device_indices,
)
71
from vllm.v1.executor import Executor
72
from vllm.v1.kv_cache_interface import KVCacheConfig
73
from vllm.v1.metrics.stats import SchedulerStats
74
from vllm.v1.outputs import ModelRunnerOutput
75
from vllm.v1.request import Request, RequestStatus
76
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
77
from vllm.v1.structured_output import StructuredOutputManager
78
from vllm.v1.utils import compute_iteration_details
79
80
81
82
from vllm.version import __version__ as VLLM_VERSION

# Module-level logger for the engine core.
logger = init_logger(__name__)

# Timeout for the engine handshake, in minutes (per the name).
# NOTE(review): the consumer of this constant is outside this view — confirm.
HANDSHAKE_TIMEOUT_MINS = 5

# Generic return type of callables dispatched through `collective_rpc`.
_R = TypeVar("_R")


class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
        include_finished_set: bool = False,
    ):
        """Build the executor, KV caches, scheduler and helper state.

        Args:
            vllm_config: Full engine configuration. Note that KV-cache
                profiling below mutates parts of it (block counts, block size,
                possibly max_model_len and chunked-prefill settings).
            executor_class: Executor type instantiated with `vllm_config`.
            log_stats: Whether the scheduler should collect stats.
            executor_fail_callback: Optional callback registered with the
                executor as its failure callback.
            include_finished_set: Forwarded to the scheduler constructor.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        if not vllm_config.parallel_config.data_parallel_rank_local:
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        # -1 until KV-cache memory has been determined (by profiling or by the
        # elastic-EP scale-up path).
        self.available_gpu_memory_for_kv_cache = -1

        if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
            self._eep_scale_up_before_kv_init()

        # Setup KV Caches and update CacheConfig after profiling.
        kv_cache_config = self._initialize_kv_caches(vllm_config)

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        scheduler_cls = vllm_config.scheduler_config.get_scheduler_cls()

        if len(kv_cache_config.kv_cache_groups) == 0:  # noqa: SIM102
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            if vllm_config.scheduler_config.enable_chunked_prefill:
                logger.warning("Disabling chunked prefill for model without KVCache")
                vllm_config.scheduler_config.enable_chunked_prefill = False

        # The scheduler works in units of the cache block size scaled by the
        # decode- and prefill-context-parallel sizes.
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
            * vllm_config.parallel_config.prefill_context_parallel_size
        )

        self.scheduler: SchedulerInterface = scheduler_cls(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=include_finished_set,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None
        if self.scheduler.connector is not None:  # type: ignore
            self.model_executor.init_kv_output_aggregator(self.scheduler.connector)  # type: ignore

        mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = mm_registry.engine_receiver_cache_from_config(
            vllm_config
        )

        # If a KV connector is initialized for scheduler, we want to collect
        # handshake metadata from all workers so the connector in the scheduler
        # will have the full context
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            # Collect and store KV connector xfer metadata from workers
            # (after KV cache registration)
            xfer_handshake_metadata = (
                self.model_executor.get_kv_connector_handshake_metadata()
            )

            if xfer_handshake_metadata:
                # xfer_handshake_metadata is list of dicts from workers
                # Each dict already has structure {tp_rank: metadata}
                # Merge all worker dicts into a single dict
                content: dict[int, Any] = {}
                for worker_dict in xfer_handshake_metadata:
                    if worker_dict is not None:
                        content.update(worker_dict)
                kv_connector.set_xfer_handshake_metadata(content)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput, Future[Any]]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.debug("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        self.is_ec_consumer = (
            vllm_config.ec_transfer_config is None
            or vllm_config.ec_transfer_config.is_ec_consumer
        )
        self.is_pooling_model = vllm_config.model_config.runner_type == "pooling"

        # Block hasher is needed for prefix caching and for KV connectors.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )
        self.async_scheduling = vllm_config.scheduler_config.async_scheduling

        # Request ids to abort; drained by _process_aborts_queue() before each
        # scheduler update.
        self.aborts_queue = queue.Queue[list[str]]()

        self._idle_state_callbacks: list[Callable] = []

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()
        # If enable, attach GC debugger after static variable freeze.
        maybe_attach_gc_debug_callback()
        # Enable environment variable cache (e.g. assume no more
        # environment variable overrides after this point)
        enable_envs_cache()

    @instrument(span_name="Prepare model")
    def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig:
        """Profile memory, build KV-cache configs and initialize the workers.

        Side effects on `vllm_config.cache_config`: sets `num_gpu_blocks` and,
        when there are KV-cache groups, `block_size` (smallest group block
        size). If KV-cache fitting changed `max_model_len`, the new value is
        pushed to the workers.

        Returns:
            The KV-cache config to hand to the scheduler.
        """
        # Monotonic clock: elapsed-time measurement must not be affected by
        # wall-clock adjustments.
        start = time.monotonic()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
                # NOTE(yongji): should already be set
                # during _eep_scale_up_before_kv_init
                assert self.available_gpu_memory_for_kv_cache > 0
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        # Track max_model_len before KV cache config to detect auto-fit changes
        max_model_len_before = vllm_config.model_config.max_model_len

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )

        # If auto-fit reduced max_model_len, sync the new value to workers.
        # This is needed because workers were spawned before memory profiling
        # and have the original (larger) max_model_len cached.
        max_model_len_after = vllm_config.model_config.max_model_len
        if max_model_len_after != max_model_len_before:
            self.collective_rpc("update_max_model_len", args=(max_model_len_after,))

        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        vllm_config.cache_config.num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        kv_cache_groups = scheduler_kv_cache_config.kv_cache_groups
        if kv_cache_groups:
            vllm_config.cache_config.block_size = min(
                g.kv_cache_spec.block_size for g in kv_cache_groups
            )

        vllm_config.validate_block_size()

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.monotonic() - start
        logger.info_once(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
            scope="local",
        )
        return scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the model executor."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case (unused by this base implementation).

        Raises:
            TypeError: if `request.request_id` is not a string.
            ValueError: if the request's pooling task is not supported.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r}. "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        if request.kv_transfer_params is not None and (
            not self.scheduler.get_kv_connector()
        ):
            # NOTE(review): this only warns; the params are left on the request.
            # Presumably they are ignored downstream without a connector — confirm.
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Execute the model and log detailed info on failure.

        Any `Exception` raised in the managed block is re-raised unchanged
        after dumping engine state for `scheduler_output`.
        """
        try:
            yield
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare raise: re-raise the active exception with its traceback.
            raise

    @contextmanager
    def log_iteration_details(self, scheduler_output: SchedulerOutput):
        """Time the managed block and log a per-iteration summary.

        Only active when `observability_config.enable_logging_iteration_details`
        is set; otherwise this is a no-op. Nothing is logged (and the iteration
        counter is not advanced) if the managed block raises.
        """
        if not self.vllm_config.observability_config.enable_logging_iteration_details:
            yield
            return
        # Lazily initialize the iteration counter on first use.
        self._iteration_index = getattr(self, "_iteration_index", 0)
        iteration_details = compute_iteration_details(scheduler_output)
        before = time.monotonic()
        yield
        elapsed_ms = (time.monotonic() - before) * 1000
        # Lazy %-style args: formatting is skipped if INFO is disabled.
        logger.info(
            "Iteration(%s): %s context requests, %s context tokens, "
            "%s generation requests, %s generation tokens, "
            "iteration elapsed time: %.2f ms",
            self._iteration_index,
            iteration_details.num_ctx_requests,
            iteration_details.num_ctx_tokens,
            iteration_details.num_generation_requests,
            iteration_details.num_generation_tokens,
            elapsed_ms,
        )
        self._iteration_index += 1

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output, non_block=True)
        # Compute the grammar bitmask while the model executes.
        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
            model_output = future.result()
            if model_output is None:
                # execute_model() produced no output; sample separately.
                model_output = self.model_executor.sample_tokens(grammar_output)

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

    def post_step(self, model_executed: bool) -> None:
        """Hand new draft token ids to the scheduler after a step.

        Only done for synchronous scheduling with speculative decoding, and
        only when the model actually executed.
        """
        # When using async scheduling we can't get draft token ids in advance,
        # so we update draft token ids in the worker process and don't
        # need to update draft token ids here.
        if not self.async_scheduling and self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        deferred_scheduler_output: SchedulerOutput | None = None
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            with self.log_error_detail(scheduler_output):
                exec_future = self.model_executor.execute_model(
                    scheduler_output, non_block=True
                )
            if self.is_ec_consumer:
                model_executed = scheduler_output.total_num_scheduled_tokens > 0

            if self.is_pooling_model or not model_executed:
                # No sampling required (no requests scheduled).
                future = cast(Future[ModelRunnerOutput], exec_future)
            elif not scheduler_output.pending_structured_output_tokens:
                # We aren't waiting for any tokens, get any grammar output
                # and sample immediately.
                grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
                future = self.model_executor.sample_tokens(
                    grammar_output, non_block=True
                )
            else:
                # We need to defer sampling until we have processed the model output
                # from the prior step.
                deferred_scheduler_output = scheduler_output

            if deferred_scheduler_output is None:
                # Add this step's future to the queue.
                batch_queue.appendleft((future, scheduler_output, exec_future))
                if (
                    model_executed
                    and len(batch_queue) < self.batch_queue_size
                    and not batch_queue[-1][0].done()
                ):
                    # Don't block on next worker response unless the queue is full
                    # or there are no more requests to schedule.
                    return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output, exec_model_fut = batch_queue.pop()
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
            model_output = future.result()
            if model_output is None:
                # None from sample_tokens() implies that the original execute_model()
                # call failed - raise that exception.
                exec_model_fut.result()
                raise RuntimeError("unexpected error")

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        # NOTE(nick): We can either handle the deferred tasks here or save
        # in a field and do it immediately once step_with_batch_queue is
        # re-called. The latter slightly favors TTFT over TPOT/throughput.
        if deferred_scheduler_output is not None:
            # If we are doing speculative decoding with structured output,
            # we need to get the draft token ids from the prior step before
            # we can compute the grammar bitmask for the deferred request.
            if self.use_spec_decode:
                draft_token_ids = self.model_executor.take_draft_token_ids()
                assert draft_token_ids is not None
                # Update the draft token ids in the scheduler output to
                # filter out the invalid spec tokens, which will be padded
                # with -1 and skipped by the grammar bitmask computation.
                self.scheduler.update_draft_token_ids_in_output(
                    draft_token_ids, deferred_scheduler_output
                )
            # We now have the tokens needed to compute the bitmask for the
            # deferred request. Get the bitmask and call sample tokens.
            grammar_output = self.scheduler.get_grammar_bitmask(
                deferred_scheduler_output
            )
            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
            batch_queue.appendleft((future, deferred_scheduler_output, exec_future))

        return engine_core_outputs, model_executed

    def _process_aborts_queue(self):
        """Drain pending abort requests and abort them as one batch."""
        # Cheap fast path: nothing queued on most steps.
        if self.aborts_queue.empty():
            return
        request_ids: list[str] = []
        # empty() is only a hint, so drain until get_nowait() reports Empty
        # instead of trusting it to guard each get.
        while True:
            try:
                ids = self.aborts_queue.get_nowait()
            except queue.Empty:
                break
            # Should be a list here, but also handle string just in case.
            request_ids.extend((ids,) if isinstance(ids, str) else ids)
        if request_ids:
            # More efficient to abort all as a single batch.
            self.abort_requests(request_ids)

    def shutdown(self):
        """Release the structured-output backend, executor and scheduler."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True, profile_prefix: str | None = None):
        """Start (`is_start=True`) or stop profiling on the model executor."""
        self.model_executor.profile(is_start, profile_prefix)

    def reset_mm_cache(self):
        """Clear the multi-modal receiver cache and the executor's mm cache."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the scheduler's prefix cache; returns the scheduler's result."""
        return self.scheduler.reset_prefix_cache(
            reset_running_requests, reset_connector
        )

    def reset_encoder_cache(self) -> None:
        """Reset the encoder cache to invalidate all cached encoder outputs.

        This should be called when model weights are updated to ensure
        stale vision embeddings computed with old weights are not reused.
        Clears both the scheduler's cache manager and the GPU model runner's cache.
        """
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the encoder cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # Reset the scheduler's encoder cache manager (logical state)
        self.scheduler.reset_encoder_cache()
        # Reset the GPU model runner's encoder cache (physical storage)
        self.model_executor.reset_encoder_cache()

    def _reset_caches(self, reset_running_requests: bool = True) -> None:
        """Reset the prefix, multi-modal and encoder caches."""
        self.reset_prefix_cache(reset_running_requests=reset_running_requests)
        self.reset_mm_cache()
        self.reset_encoder_cache()

    def pause_scheduler(
        self, mode: PauseMode = "abort", clear_cache: bool = True
    ) -> Future | None:
        """Pause generation; behavior depends on mode.

        All pause modes queue new adds -- "abort" and "keep" skip step();
        "wait" allows step() so in-flight requests can drain.

        - ``abort``: Set PAUSED_NEW, abort all requests, wait for abort
          outputs to be sent (when running with output_queue), optionally
          clear caches, then complete the returned Future.
        - ``wait``: Set PAUSED_NEW (queue adds, keep stepping); when drained,
          optionally clear caches, then complete the returned Future.
        - ``keep``: Set PAUSED_ALL; return a Future that completes when the
          output queue is empty.

        This in-process implementation rejects ``wait`` with ValueError and
        always returns None (the pause takes effect synchronously).
        """
        if mode not in ("keep", "abort", "wait"):
            raise ValueError(f"Invalid pause mode: {mode}")
        if mode == "wait":
            raise ValueError("'wait' mode can't be used in inproc-engine mode")

        if mode == "abort":
            # None = finish every request.
            self.scheduler.finish_requests(None, RequestStatus.FINISHED_ABORTED)

        pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
        self.scheduler.set_pause_state(pause_state)
        if clear_cache:
            self._reset_caches()

        return None

    def resume_scheduler(self) -> None:
        """Resume the scheduler and flush any requests queued while paused."""
        self.scheduler.set_pause_state(PauseState.UNPAUSED)

    def is_scheduler_paused(self) -> bool:
        """Return whether the scheduler is in any pause state."""
        return self.scheduler.pause_state != PauseState.UNPAUSED

    def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None | Future:
        """Put the engine to sleep at the specified level.

        Args:
            level: Sleep level.
                - Level 0: Pause scheduling only. Requests are still accepted
                           but not processed. No GPU memory changes.
                - Level 1: Offload model weights to CPU, discard KV cache.
                - Level 2: Discard all GPU memory.
            mode: Pause mode - how to deal with any existing requests, see
                documentation of pause_scheduler method.

        Returns:
            None if sleeping completed synchronously, otherwise a Future that
            resolves once the pause finishes and the executor has slept.
        """

        # Pause scheduler before sleeping.
        clear_prefix_cache = level >= 1
        pause_future = self.pause_scheduler(mode=mode, clear_cache=clear_prefix_cache)
        if level < 1:
            return pause_future

        # Level 1+: Delegate to executor for GPU memory management
        model_executor = self.model_executor
        if pause_future is None:
            model_executor.sleep(level)
            return None

        future = Future[Any]()

        def pause_complete(f: Future):
            try:
                f.result()  # propagate any exception
                future.set_result(model_executor.sleep(level))
            except Exception as e:
                future.set_exception(e)

        logger.info("Waiting for in-flight requests to complete before sleeping...")
        pause_future.add_done_callback(pause_complete)
        return future

    def wake_up(self, tags: list[str] | None = None):
        """Wake up the engine from sleep.

        Args:
            tags: Tags to wake up. Use ["scheduling"] for level 0 wake up.
        """
        if tags is not None and "scheduling" in tags:
            # Remove "scheduling" from tags if there are other tags to process.
            tags = [t for t in tags if t != "scheduling"]

        # None means "wake everything"; an empty list means only scheduling
        # was requested, so the executor is left alone.
        if tags is None or tags:
            self.model_executor.wake_up(tags)

        # Resume scheduling (applies to all levels)
        self.resume_scheduler()

    def is_sleeping(self) -> bool:
        """Check if engine is sleeping at any level."""
        return self.is_scheduler_paused() or self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run a dummy batch on the model executor."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter via the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter via the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids known to the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter via the executor."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Forward a sharded-state save request to the executor."""
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Invoke `method` through the executor's collective RPC."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The internal `Request` and the request's `current_wave`.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave

    def _eep_scale_up_before_kv_init(self):
        """Elastic-EP scale-up hook run before KV-cache init.

        Not implemented here; presumably overridden by a subclass — confirm.
        """
        raise NotImplementedError

    def _eep_send_engine_core_notification(
        self,
        notification_type: EEPNotificationType,
        vllm_config: VllmConfig | None = None,
    ):
        """Elastic-EP notification hook; not implemented in this class."""
        raise NotImplementedError


class EngineShutdownState(IntEnum):
    """Shutdown progress of an engine core busy loop.

    In this module the state is advanced RUNNING -> REQUESTED (by the signal
    handler installed in ``run_engine_core``) -> SHUTTING_DOWN (by
    ``_handle_shutdown`` once it has acted on the request).
    """

    # Normal operation; new ADD/UTILITY requests are accepted.
    RUNNING = 0
    # A shutdown signal arrived but the busy loop has not acted on it yet.
    REQUESTED = 1
    # Draining/aborting in progress; the loop exits once no work remains.
    SHUTTING_DOWN = 2


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

782
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
783
    addresses: EngineZmqAddresses
784

785
    @instrument(span_name="EngineCoreProc init")
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        tensor_queue: Queue | None = None,
        *,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, build the engine core, start IO threads.

        Args:
            vllm_config: Engine configuration. The startup handshake may update
                its ``parallel_config`` in place.
            local_client: Whether a front-end client is colocated; ``False``
                means this engine runs headless.
            handshake_address: ZMQ address for the (first) startup handshake.
            executor_class: Executor class forwarded to ``EngineCore``.
            log_stats: Forwarded to ``EngineCore``.
            client_handshake_address: Optional address of a second handshake
                with the colocated front-end, which supplies the client
                input/output socket addresses.
            tensor_queue: Optional multiprocessing queue from which multimodal
                tensors are received out-of-band.
            engine_index: Index of this engine; also encoded as a 2-byte
                little-endian ZMQ socket identity.

        Raises:
            RuntimeError: If the input socket thread dies while waiting for the
                DP coordinator READY message.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()
        # Called by the executor on failure: the busy loop raises when it
        # dequeues EXECUTOR_FAILED (see _handle_client_request).
        executor_fail_callback = lambda: self.input_queue.put_nowait(
            (EngineCoreRequestType.EXECUTOR_FAILED, b"")
        )

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False
        self.shutdown_state = EngineShutdownState.RUNNING

        # Receiver for tensor IPC
        self.tensor_ipc_receiver: TensorIpcReceiver | None = None
        if tensor_queue is not None:
            self.tensor_ipc_receiver = TensorIpcReceiver(tensor_queue)
            logger.info("Using tensor IPC queue for multimodal tensor sharing")

        # READY is sent to the front-end(s) only when this `with` body
        # completes, i.e. after the engine core and IO threads are up.
        with self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        ) as addresses:
            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            internal_dp_balancing = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )
            # Only publish request queue stats to coordinator for "internal"
            # and "hybrid" LB modes.
            self.publish_dp_lb_stats = internal_dp_balancing

            self.addresses = addresses
            self.process_input_queue_block = True
            if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
                self._eep_send_engine_core_notification(
                    EEPNotificationType.NEW_CORE_ENGINES_INIT_READY,
                    vllm_config=vllm_config,
                )

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config,
                executor_class,
                log_stats,
                executor_fail_callback,
                internal_dp_balancing,
            )

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    identity,
                    ready_event,
                ),
                daemon=True,
            )
            input_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(
                    addresses.outputs,
                    addresses.coordinator_output,
                    self.engine_index,
                ),
                daemon=True,
            )
            self.output_thread.start()

            # Don't complete handshake until DP coordinator ready message is
            # received.
            while not ready_event.wait(timeout=10):
                if not input_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.

        Yields the (merged) ``EngineZmqAddresses``. READY messages are sent
        when the ``with`` body exits normally; afterwards
        ``vllm_config.__post_init__()`` is re-run because the first handshake
        may have updated ``vllm_config.parallel_config`` in place.
        """
        input_ctx = zmq.Context()
        is_local = local_client and client_handshake_address is None
        headless = not local_client
        handshake = self._perform_handshake(
            input_ctx,
            handshake_address,
            identity,
            is_local,
            headless,
            vllm_config,
            vllm_config.parallel_config,
        )
        if client_handshake_address is None:
            with handshake as addresses:
                yield addresses
        else:
            assert local_client
            # Second handshake with the colocated front-end; only its client
            # input/output addresses are used.
            local_handshake = self._perform_handshake(
                input_ctx, client_handshake_address, identity, True, False, vllm_config
            )
            with handshake as addresses, local_handshake as client_addresses:
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
                yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run one HELLO / init / READY handshake with a front-end process.

        Connects a DEALER socket (linger 5000 ms) to ``handshake_address``,
        registers via ``startup_handshake`` and yields the addresses it
        returns. If the ``with`` body completes without raising, a READY
        message is sent with ``num_gpu_blocks``, the DP stats publish address
        and, when ``data_parallel_size > 1``, the parallel config hash.

        Args:
            ctx: ZMQ context used to create the handshake socket.
            handshake_address: Address of the front-end handshake socket.
            identity: ZMQ identity of this engine.
            local_client: Reported to the front-end as the ``local`` flag.
            headless: Reported to the front-end as the ``headless`` flag.
            vllm_config: Config read when building the READY message.
            parallel_config_to_update: If given, updated in place from the
                front-end's init message.
        """
        with make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        ) as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(
                handshake_socket, local_client, headless, parallel_config_to_update
            )
            yield addresses

            # Send ready message.
            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
            # We pass back the coordinator stats update address here for the
            # external LB case for our colocated front-end to use (coordinator
            # only runs with rank 0).
            dp_stats_address = self.frontend_stats_publish_address

            # Include config hash for DP configuration validation
            ready_msg = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": num_gpu_blocks,
                "dp_stats_address": dp_stats_address,
            }
            if vllm_config.parallel_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = (
                    vllm_config.parallel_config.compute_hash()
                )

            handshake_socket.send(msgspec.msgpack.encode(ready_msg))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Register with the front-end and receive its init message.

        Sends a HELLO message, then waits up to ``HANDSHAKE_TIMEOUT_MINS``
        minutes for an ``EngineHandshakeMetadata`` reply.

        Args:
            handshake_socket: Connected handshake socket.
            local_client: Sent to the front-end as the ``local`` flag.
            headless: Sent to the front-end as the ``headless`` flag.
            parallel_config: If given, every entry of the reply's
                ``parallel_config`` mapping is set on it via ``setattr``.

        Returns:
            The ``addresses`` carried by the init message.

        Raises:
            RuntimeError: If no reply arrives before the timeout.
        """
        # Send registration message.
        handshake_socket.send(
            msgspec.msgpack.encode(
                {
                    "status": "HELLO",
                    "local": local_client,
                    "headless": headless,
                }
            )
        )

        # Receive initialization message.
        logger.debug("Waiting for init message from front-end.")
        # poll() takes milliseconds.
        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} minutes"
            )
        init_bytes = handshake_socket.recv()
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            init_bytes, type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for key, value in init_message.parallel_config.items():
                setattr(parallel_config, key, value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process.

        Configures per-DP-rank settings on ``kwargs["vllm_config"]``, builds a
        ``DPEngineCoreProc`` (DP with an MoE model) or an ``EngineCoreProc``
        (everything else), installs SIGTERM/SIGINT handlers that request a
        graceful shutdown, then runs the busy loop. The engine core is always
        shut down and default signal handlers are restored on exit.

        Args:
            *args: Positional arguments forwarded to the engine core constructor.
            dp_rank: Data parallel rank; used in the process title, stored as
                ``parallel_config.data_parallel_index`` and, for MoE DP, as
                ``parallel_config.data_parallel_rank``.
            local_dp_rank: Stored as ``parallel_config.data_parallel_rank_local``
                and appended to the KV-transfer ``engine_id`` when DP is used.
            **kwargs: Keyword arguments forwarded to the engine core
                constructor; must contain ``vllm_config``.
        """

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        engine_core: EngineCoreProc | None = None
        signal_callback: SignalCallback | None = None
        try:
            vllm_config: VllmConfig = kwargs["vllm_config"]
            parallel_config: ParallelConfig = vllm_config.parallel_config
            data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
            if data_parallel:
                parallel_config.data_parallel_rank_local = local_dp_rank
                process_title = f"EngineCore_DP{dp_rank}"
            else:
                process_title = "EngineCore"
            set_process_title(process_title)
            maybe_init_worker_tracer("vllm.engine_core", "engine_core", process_title)
            decorate_logs()

            if parallel_config.numa_bind:
                numa_utils.log_current_affinity_state(process_title)

            if data_parallel and vllm_config.kv_transfer_config is not None:
                # modify the engine_id and append the local_dp_rank to it to ensure
                # that the kv_transfer_config is unique for each DP rank.
                vllm_config.kv_transfer_config.engine_id = (
                    f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
                )
                logger.debug(
                    "Setting kv_transfer_config.engine_id to %s",
                    vllm_config.kv_transfer_config.engine_id,
                )

            parallel_config.data_parallel_index = dp_rank
            if data_parallel and vllm_config.model_config.is_moe:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                # Non-MoE DP ranks are completely independent, so treat like DP=1.
                # Note that parallel_config.data_parallel_index will still reflect
                # the original DP rank.
                parallel_config.data_parallel_size = 1
                parallel_config.data_parallel_size_local = 1
                parallel_config.data_parallel_rank = 0
                engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)

            assert engine_core is not None

            def wakeup_engine():
                # Wakes up idle engine via input_queue when shutdown is requested
                # Not safe in a signal handler - we may interrupt the main thread
                # while it is holding the non-reentrant input_queue.mutex
                engine_core.input_queue.put_nowait((EngineCoreRequestType.WAKEUP, None))

            signal_callback = SignalCallback(wakeup_engine)

            def signal_handler(signum, frame):
                # Only flip the state flag here and defer the queue wake-up to
                # the callback (see wakeup_engine for why).
                engine_core.shutdown_state = EngineShutdownState.REQUESTED
                signal_callback.trigger()

            signal.signal(signal.SIGTERM, signal_handler)
            signal.signal(signal.SIGINT, signal_handler)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare raise re-raises the active exception with its original
            # traceback (``raise e`` would append this frame again).
            raise
        finally:
            signal.signal(signal.SIGTERM, signal.SIG_DFL)
            signal.signal(signal.SIGINT, signal.SIG_DFL)
            if signal_callback is not None:
                signal_callback.stop()
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Data-parallel setup hook.

        Called from ``__init__`` after the handshake and before the base
        ``EngineCore`` is initialized. Does nothing in this class.
        """

    def has_work(self) -> bool:
        """Report whether the busy loop should step the engine.

        Returns the first truthy of ``engines_running`` and
        ``scheduler.has_requests()``; otherwise whether ``batch_queue`` is
        non-empty.
        """
        if running := self.engines_running:
            return running
        if queued := self.scheduler.has_requests():
            return queued
        return bool(self.batch_queue)

    def is_running(self) -> bool:
        """Return True as long as no shutdown has been requested."""
        return EngineShutdownState.RUNNING == self.shutdown_state

    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Alternates between servicing the input queue and stepping the engine
        until ``_handle_shutdown`` returns False, then raises ``SystemExit``.
        """
        while True:
            if not self._handle_shutdown():
                break
            # 1) Wait on / drain the input queue until there is work to do.
            self._process_input_queue()
            # 2) Run one engine step and publish its outputs.
            self._process_engine_step()
        raise SystemExit

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed.

        While there is no work and no shutdown has been requested, notifies
        idle-state callbacks and takes requests from ``input_queue`` (blocking
        iff ``process_input_queue_block``; a non-blocking pass handles at most
        one request). Afterwards, drains any remaining queued requests without
        blocking.
        """

        waited = False
        while not self.has_work() and self.is_running():
            # Notify callbacks waiting for engine to become idle.
            self._notify_idle_state_callbacks()
            if self.input_queue.empty():
                # Drain aborts queue; all aborts are also processed via input_queue.
                with self.aborts_queue.mutex:
                    self.aborts_queue.queue.clear()
                if logger.isEnabledFor(DEBUG):
                    logger.debug("EngineCore waiting for work.")
                    # Only tracked in debug so the matching "active" message
                    # below is emitted only when the "waiting" one was.
                    waited = True
            block = self.process_input_queue_block
            try:
                req = self.input_queue.get(block=block)
                self._handle_client_request(*req)
            except queue.Empty:
                break
            if not block:
                break

        if waited:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self) -> bool:
1181
1182
1183
        """Called only when there are unfinished local requests."""

        # Step the engine core.
1184
        outputs, model_executed = self.step_fn()
1185
        # Put EngineCoreOutputs into the output queue.
1186
        for output in outputs.items() if outputs else ():
1187
            self.output_queue.put_nowait(output)
1188
1189
        # Post-step hook.
        self.post_step(model_executed)
1190

1191
1192
1193
1194
1195
1196
1197
        # If no model execution happened but there are waiting requests
        # (e.g., WAITING_FOR_REMOTE_KVS), yield the GIL briefly to allow
        # background threads (like NIXL handshake) to make progress.
        # Without this, the tight polling loop can starve background threads.
        if not model_executed and self.scheduler.has_unfinished_requests():
            time.sleep(0.001)

1198
1199
        return model_executed

1200
1201
1202
1203
    def _notify_idle_state_callbacks(self) -> None:
        while self._idle_state_callbacks:
            callback = self._idle_state_callbacks.pop()
            callback(self)
1204

1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
    def _handle_shutdown(self) -> bool:
        """Advance shutdown handling; return False when the busy loop should exit.

        On the first call after a shutdown request, either aborts every
        unfinished request (``shutdown_timeout == 0``) or lets them keep
        running, then moves to ``SHUTTING_DOWN``. From then on, returns False
        as soon as ``has_work()`` is False.
        """
        if self.shutdown_state == EngineShutdownState.RUNNING:
            return True

        if self.shutdown_state == EngineShutdownState.REQUESTED:
            shutdown_timeout = self.vllm_config.shutdown_timeout
            logger.info("Shutdown initiated (timeout=%d)", shutdown_timeout)

            num_requests = self.scheduler.get_num_unfinished_requests()
            if shutdown_timeout == 0:
                # No grace period: abort everything still in flight now.
                if num_requests > 0:
                    logger.info("Aborting %d requests", num_requests)
                aborted_reqs = self.scheduler.finish_requests(
                    None, RequestStatus.FINISHED_ABORTED
                )
                self._send_abort_outputs(aborted_reqs)
            elif num_requests > 0:
                # NOTE(review): the timeout is only logged here; the loop keeps
                # stepping until idle. Enforcement presumably happens
                # elsewhere — confirm.
                logger.info(
                    "Draining %d in-flight requests (timeout=%ds)",
                    num_requests,
                    shutdown_timeout,
                )

            self.shutdown_state = EngineShutdownState.SHUTTING_DOWN

        # Exit only once nothing is left to step.
        if self.has_work():
            return True
        logger.info("Shutdown complete")
        return False

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch one ``(request_type, request)`` item taken from ``input_queue``.

        Raises:
            RuntimeError: For ``EXECUTOR_FAILED`` items.
        """
        if request_type == EngineCoreRequestType.WAKEUP:
            # Only used to unblock the busy loop.
            return

        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            if not self._reject_add_in_shutdown(req):
                self.add_request(req, request_wave)
            return

        if request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
            return

        if request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            if self._reject_utility_in_shutdown(client_idx, call_id, method_name):
                return

            def get_result():
                # Look the method up lazily so a bad name is reported back to
                # the caller as a failed utility call instead of raising here.
                method = getattr(self, method_name)
                return method and method(*self._convert_msgspec_args(method, args))

            def enqueue_output(out):
                self.output_queue.put_nowait(
                    (client_idx, EngineCoreOutputs(utility_output=out))
                )

            self._invoke_utility_method(
                method_name, get_result, UtilityOutput(call_id), enqueue_output
            )
            return

        if request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")

        logger.error("Unrecognized input request type encountered: %s", request_type)

    def _reject_add_in_shutdown(self, request: Request) -> bool:
        """Abort ``request`` back to its client if shutdown has begun.

        Returns:
            True if the request was rejected, False while the engine is running.
        """
        if self.shutdown_state != EngineShutdownState.RUNNING:
            logger.info(
                "Rejecting request %s (server shutting down)", request.request_id
            )
            self._send_abort_outputs_to_client(
                [request.request_id], request.client_index
            )
            return True
        return False

    def _reject_utility_in_shutdown(
        self, client_idx: int, call_id: int, method_name: str
    ) -> bool:
        """Fail a utility call back to its client if shutdown has begun.

        Returns:
            True if the call was rejected (a failed ``UtilityOutput`` was
            enqueued for ``client_idx``), False while the engine is running.
        """
        if self.shutdown_state != EngineShutdownState.RUNNING:
            logger.warning(
                "Rejecting utility call %s (server shutting down)", method_name
            )
            failed = UtilityOutput(call_id, failure_message="Server shutting down")
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=failed))
            )
            return True
        return False

    @staticmethod
    def _invoke_utility_method(
        name: str, get_result: Callable, output: UtilityOutput, enqueue_output: Callable
    ):
        """Run a utility call and enqueue its ``UtilityOutput``.

        If the call returns a ``Future``, output handling is deferred to the
        future's done-callback, which re-enters this method with
        ``future.result``. Exceptions are logged and reported through
        ``output.failure_message`` instead of propagating.
        """
        try:
            result = get_result()
            if not isinstance(result, Future):
                output.result = UtilityResult(result)
            else:

                def _on_done(future: Future) -> None:
                    EngineCoreProc._invoke_utility_method(
                        name, future.result, output, enqueue_output
                    )

                result.add_done_callback(_on_done)
                return
        except Exception as e:
            logger.exception("Invocation of %s method failed", name)
            output.failure_message = f"Call to {name} method failed: {str(e)}"
        enqueue_output(output)

    @staticmethod
    def _convert_msgspec_args(method, args):
        """Coerce positional ``args`` to the msgspec Struct types ``method`` expects.

        An argument is converted with ``msgspec.convert`` when its parameter
        annotation is a ``msgspec.Struct`` subclass and the value is not
        already an instance of it; other values pass through. Empty ``args``
        are returned unchanged.
        """
        if not args:
            return args
        params = list(signature(method).parameters.values())
        assert len(args) <= len(params)

        def coerce(value, param):
            target = param.annotation
            needs_convert = (
                isclass(target)
                and issubclass(target, msgspec.Struct)
                and not isinstance(value, target)
            )
            return msgspec.convert(value, type=target) if needs_convert else value

        return tuple(coerce(v, p) for v, p in zip(args, params))

    def _send_engine_dead(self):
        """Tell the EngineCoreClient that this engine core has died.

        Enqueues the ``ENGINE_CORE_DEAD`` sentinel and waits up to 5 seconds
        for the output thread, which forwards it to every client and then exits.
        """
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # The output thread is a daemon: give it time to flush the sentinel
        # before shutdown proceeds.
        sender = self.output_thread
        sender.join(timeout=5.0)
        if not sender.is_alive():
            return
        logger.fatal(
            "vLLM shutdown signal from EngineCore failed "
            "to send. Please report this issue."
        )

    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Connects a DEALER socket to each front-end input address (plus an XSUB
        socket to the DP coordinator when ``coord_input_address`` is set),
        decodes incoming messages and pushes ``(request_type, request)`` onto
        ``input_queue`` for the busy loop. ADD requests are preprocessed on
        this thread; ABORTs are also pushed onto ``aborts_queue``.

        ``ready_event`` is set once all sockets are connected and, when a
        coordinator is used, its READY message has been received.

        Raises:
            RuntimeError: If the first message from the DP coordinator is not
                ``b"READY"``.
        """

        # Msgpack serialization decoding with optional tensor IPC receiver.
        add_request_decoder = MsgpackDecoder(
            EngineCoreRequest, oob_tensor_provider=self.tensor_ipc_receiver
        )
        generic_decoder = MsgpackDecoder(oob_tensor_provider=self.tensor_ipc_receiver)

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. Checked explicitly
                # rather than with `assert` so the blocking recv() still runs
                # (and the handshake is still enforced) under `python -O`.
                ready_msg = coord_socket.recv()
                if ready_msg != b"READY":
                    raise RuntimeError(
                        "Expected READY message from DP Coordinator, "
                        f"got {ready_msg!r}"
                    )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            # Drop our reference; the event is only used for startup.
            del ready_event

            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    # NOTE(yongji): ignore READY message sent by DP coordinator
                    # that is used to notify newly started engines
                    if type_frame.buffer == b"READY":
                        assert input_socket == coord_socket
                        continue
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    request: Any
                    if request_type == EngineCoreRequestType.ADD:
                        req: EngineCoreRequest = add_request_decoder.decode(data_frames)
                        try:
                            request = self.preprocess_add_request(req)
                        except Exception:
                            self._handle_request_preproc_error(req)
                            continue
                    else:
                        request = generic_decoder.decode(data_frames)

                        if request_type == EngineCoreRequestType.ABORT:
                            # Aborts are added to *both* queues, allows us to eagerly
                            # process aborts while also ensuring ordering in the input
                            # queue to avoid leaking requests. This is ok because
                            # aborting in the scheduler is idempotent.
                            self.aborts_queue.put_nowait(request)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(
        self, output_paths: list[str], coord_output_path: str | None, engine_index: int
    ):
        """Output socket IO thread.

        Takes ``(client_index, EngineCoreOutputs)`` items from ``output_queue``,
        stamps them with ``engine_index`` and sends them on the client's PUSH
        socket, or to the DP coordinator socket when ``client_index`` is -1.
        Encode buffers are recycled once zmq's tracker reports the send done.
        Exits after forwarding ``ENGINE_CORE_DEAD`` to every client socket.
        """

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        # Newest entries are on the left, oldest on the right.
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                )
                for output_path in output_paths
            ]
            coord_socket = (
                stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
                if coord_output_path is not None
                else None
            )
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with.
                while pending and pending[-1][0].done:
                    reuse_buffers.append(pending.pop()[2])

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(
                    buffers, copy=False, track=True
                )
                if not tracker.done:
                    # Send still in flight: keep the buffer, and the outputs
                    # object when extra zero-copy frames reference it, alive.
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)

    def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None:
        """Report a failure raised while pre-processing an add-request.

        The failure is logged via ``logger.exception`` (so this should be
        invoked while the exception is being handled). Only the offending
        request is then failed: an ERROR finish output is queued for the
        client that submitted it.
        """
        failed_id = request.request_id
        logger.exception("Unexpected error pre-processing request %s", failed_id)
        self._send_error_outputs_to_client([failed_id], request.client_index)

1511
1512
1513
1514
1515
    def pause_scheduler(
        self, mode: PauseMode = "abort", clear_cache: bool = True
    ) -> Future | None:
        """Pause the scheduler in one of three modes.

        All modes queue newly added requests. ``abort`` and ``keep`` stop
        stepping, while ``wait`` keeps stepping so in-flight requests drain.

        - ``abort``: finish every request as FINISHED_ABORTED, queue the abort
          outputs for their clients, then set PAUSED_NEW.
        - ``wait``: set PAUSED_NEW; in-flight requests keep running.
        - ``keep``: set PAUSED_ALL.

        Args:
            mode: One of ``"abort"``, ``"wait"`` or ``"keep"``.
            clear_cache: Whether to reset the engine caches once idle.

        Returns:
            ``None`` if the engine has no work left (caches are reset right
            away when ``clear_cache``). Otherwise a Future that an idle-state
            callback resolves (after the optional cache reset).

        Raises:
            ValueError: If ``mode`` is not a recognised pause mode.
        """
        if mode not in ("keep", "abort", "wait"):
            raise ValueError(f"Invalid pause mode: {mode}")

        if mode == "abort":
            self._send_abort_outputs(
                self.scheduler.finish_requests(None, RequestStatus.FINISHED_ABORTED)
            )

        self.scheduler.set_pause_state(
            PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
        )

        if self.has_work():
            # Defer completion until the engine reports it is idle.
            def _on_engine_idle(
                engine: "EngineCoreProc", future: Future[Any]
            ) -> None:
                if clear_cache:
                    engine._reset_caches()
                future.set_result(None)

            idle_future = Future[Any]()
            self._idle_state_callbacks.append(
                partial(_on_engine_idle, future=idle_future)
            )
            return idle_future

        # Already idle: apply the cache reset synchronously.
        if clear_cache:
            self._reset_caches()
        return None

1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
    def _send_finish_outputs_to_client(
        self, req_ids: list[str], client_index: int, finish_reason: FinishReason
    ) -> None:
        """Queue a batch that marks ``req_ids`` finished for one client.

        Each request gets an output with no new tokens and the given
        ``finish_reason``. The batch is put on the output queue tagged with
        ``client_index``.
        """
        per_request = []
        for rid in req_ids:
            # A fresh empty token list per output; never share one list.
            per_request.append(EngineCoreOutput(rid, [], finish_reason=finish_reason))
        batch = EngineCoreOutputs(finished_requests=req_ids, outputs=per_request)
        self.output_queue.put_nowait((client_index, batch))

    def _send_abort_outputs_to_client(
        self, req_ids: list[str], client_index: int
    ) -> None:
        """Tell ``client_index`` that ``req_ids`` finished with reason ABORT."""
        self._send_finish_outputs_to_client(
            req_ids, client_index, finish_reason=FinishReason.ABORT
        )

    def _send_error_outputs_to_client(
        self, req_ids: list[str], client_index: int
    ) -> None:
        """Tell ``client_index`` that ``req_ids`` finished with reason ERROR."""
        self._send_finish_outputs_to_client(
            req_ids, client_index, finish_reason=FinishReason.ERROR
        )

1572
    def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None:
1573
        # TODO(nick) this will be moved inside the scheduler
1574
1575
1576
1577
1578
1579
        if aborted_reqs:
            # Map client_index to list of request_ids that belong to that client.
            by_client = defaultdict[int, set[str]](set)
            for req_id, client_index in aborted_reqs:
                by_client[client_index].add(req_id)
            for client_index, req_ids in by_client.items():
1580
                self._send_abort_outputs_to_client(list(req_ids), client_index)
1581

1582
1583
1584
1585
1586
1587
1588
1589

class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in a background process
    in a data parallel context.

    Adds DP "wave" bookkeeping on top of EngineCoreProc. Idle ranks run
    dummy batches while peers still have work. A periodic all-reduce decides
    when every rank is done and the current wave can end.
    """

    # Number of engine steps between finish-sync all-reduces across DP ranks.
    # Between syncs the ranks are assumed to still be running. This avoids
    # an all-reduce on every step.
    _FINISH_SYNC_INTERVAL = 32

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        tensor_queue: Queue | None = None,
    ):
        assert vllm_config.model_config.is_moe, (
            "DPEngineCoreProc should only be used for MoE models"
        )

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        self.current_wave = 0
        # Last request counts published for DP load balancing.
        self.last_counts = (0, 0)

        from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState

        # Set while an elastic-EP scale up/down is in progress.
        self.eep_scaling_state: ElasticEPScalingState | None = None

        # Initialize the engine; the global DP rank doubles as engine index.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            engine_index=dp_rank,
            tensor_queue=tensor_queue,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Validate the DP ranks and create the stateless DP process group."""
        # Configure GPUs and stateless process group for data parallel.
        parallel_config = vllm_config.parallel_config
        dp_rank = parallel_config.data_parallel_rank
        dp_size = parallel_config.data_parallel_size
        local_dp_rank = parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert local_dp_rank is not None
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        self.dp_rank = dp_rank
        dp_group, dp_store = parallel_config.stateless_init_dp_group(return_store=True)
        self.dp_group, self.dp_store = dp_group, dp_store

    def shutdown(self):
        """Shut down the base engine, then destroy the DP group if created."""
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request and reconcile the local wave with ``request_wave``.

        With a coordinator, a request from a newer wave moves this engine to
        that wave. A request for an already-completed wave restarts the loop
        (when not paused) and asks the coordinator to start the next wave.
        """
        super().add_request(request, request_wave)
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif (
                not self.engines_running
                and self.scheduler.pause_state == PauseState.UNPAUSED
            ):
                self.engines_running = True
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

    def resume_scheduler(self):
        """Resume scheduling; wake DP peers if we hold unfinished requests."""
        super().resume_scheduler()
        if (
            self.has_coordinator
            and not self.engines_running
            and self.scheduler.has_unfinished_requests()
        ):
            # Wake up other DP engines.
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(start_wave=self.current_wave))
            )

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle START_DP_WAVE locally; delegate other request types.

        For START_DP_WAVE, ``request`` is ``(new_wave, exclude_eng_index)``.
        The engine named by ``exclude_eng_index`` ignores the message, and
        stale waves are ignored as well.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Send scheduler request counts to the coordinator when they change."""
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while self._handle_shutdown():
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Advance an in-progress elastic EP scale up/down.
            if self.eep_scaling_state is not None:
                self.eep_scaling_state.progress()
                if self.eep_scaling_state.is_complete():
                    if self.eep_scaling_state.worker_type == "removing":
                        raise SystemExit
                    self.process_input_queue_block = True
                    self.eep_scaling_state = None

            # 3) Step the engine and publish load-balancing stats.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 4) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

        raise SystemExit

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank still has unfinished requests.

        Increments ``step_counter``. The cross-rank all-reduce runs only every
        ``_FINISH_SYNC_INTERVAL`` steps; on other steps this returns ``True``.
        """
        self.step_counter += 1
        if self.step_counter % self._FINISH_SYNC_INTERVAL != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Start an elastic EP reconfiguration described by ``reconfig_request``.

        Builds the new parallel config and installs an ElasticEPScalingState.
        That state is then advanced from ``run_busy_loop``. Input-queue
        blocking is disabled while scaling is in progress.
        """
        from copy import deepcopy

        from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState

        new_parallel_config = deepcopy(self.vllm_config.parallel_config)
        old_dp_size = new_parallel_config.data_parallel_size
        new_parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            new_parallel_config.data_parallel_rank = (
                reconfig_request.new_data_parallel_rank
            )
        new_parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        new_parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        new_parallel_config._data_parallel_master_port_list = (
            reconfig_request.new_data_parallel_master_port_list
        )
        new_parallel_config._coord_store_port = reconfig_request.coord_store_port

        is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size
        is_shutdown = (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        )

        self.eep_scaling_state = ElasticEPScalingState(
            model_executor=self.model_executor,
            engine_core=self,
            vllm_config=self.vllm_config,
            new_parallel_config=new_parallel_config,
            worker_type="removing" if is_shutdown else "existing",
            scale_type="scale_down" if is_scale_down else "scale_up",
            reconfig_request=reconfig_request,
        )
        self.process_input_queue_block = False
        logger.info(
            "[Elastic EP] Received reconfiguration request and starting scaling up/down"
        )

    def _eep_send_engine_core_notification(
        self,
        notification_type: EEPNotificationType,
        vllm_config: VllmConfig | None = None,
    ):
        """
        Send notifications to EngineCoreClient, which can then forward
        the notifications to other engine core processes. It is used for:
        1) In scale up: new core engines to notify existing core engines
           that they are ready;
        2) In scale down: removing core engines to notify EngineCoreClient
           so EngineCoreClient can release their ray placement groups;
        3) Both scale up/down: to notify EngineCoreClient that existing
           core engines have already switched to the new parallel setup.

        The DP rank reported is taken from ``vllm_config`` when given,
        otherwise from this engine's config. If the output thread is alive
        the message goes through the output queue. Otherwise it is sent
        directly on a temporary PUSH socket to the first output address.
        """
        if vllm_config is None:
            dp_rank = self.vllm_config.parallel_config.data_parallel_rank
        else:
            dp_rank = vllm_config.parallel_config.data_parallel_rank
        notification_data = (notification_type.value, dp_rank)
        outputs = EngineCoreOutputs(
            utility_output=UtilityOutput(
                call_id=EEP_NOTIFICATION_CALL_ID,
                result=UtilityResult(notification_data),
            )
        )
        outputs.engine_index = self.engine_index

        if hasattr(self, "output_thread") and self.output_thread.is_alive():
            self.output_queue.put_nowait((0, outputs))
        else:
            encoder = MsgpackEncoder()
            with (
                zmq.Context() as ctx,
                make_zmq_socket(
                    ctx, self.addresses.outputs[0], zmq.PUSH, linger=4000
                ) as socket,
            ):
                socket.send_multipart(encoder.encode(outputs))

    def eep_handle_engine_core_notification(
        self, notification_type: str | EEPNotificationType
    ):
        """
        Handle notification received from EngineCoreClient
        (forwarded from new core engines).

        Must only be called while a scaling operation is in progress.
        """
        assert self.eep_scaling_state is not None
        if isinstance(notification_type, str):
            notification_type = EEPNotificationType(notification_type)
        self.eep_scaling_state.handle_notification(notification_type)

    def _eep_scale_up_before_kv_init(self):
        """Run the pre-KV-cache-init scale-up states for a newly added engine."""
        from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState

        self.eep_scaling_state = ElasticEPScalingState(
            model_executor=self.model_executor,
            engine_core=self,
            vllm_config=self.vllm_config,
            new_parallel_config=self.vllm_config.parallel_config,
            worker_type="new",
            scale_type="scale_up",
            reconfig_request=None,
        )
        self.eep_scaling_state.run_pre_kv_init_states()
        self.process_input_queue_block = False

Rui Qiao's avatar
Rui Qiao committed
1879

1880
class EngineCoreActorMixin:
    """
    Ray actor for running EngineCore in a data parallel context.

    Mixed in ahead of an EngineCoreProc class. It does the Ray-specific
    setup: tracing, pre-known ZMQ addresses, and device visibility. It
    replaces the ZMQ handshake with the addresses it already holds.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        addresses: EngineZmqAddresses,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Initialize tracer for distributed tracing if configured.
        maybe_init_worker_tracer(
            instrumenting_module_name="vllm.engine_core",
            process_kind="engine_core",
            process_name=f"DPEngineCoreActor_DP{dp_rank}",
        )

        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_index = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Narrow the platform's device-control env var to this DP rank.

        XPU is left untouched.
        """
        from vllm.platforms import current_platform

        if current_platform.is_xpu():
            return
        self._set_cuda_visible_devices(
            vllm_config, local_dp_rank, current_platform.device_control_env_var
        )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Set ``device_control_env_var`` to this rank's slice of devices.

        The slice covers ``[local_dp_rank * world_size,
        (local_dp_rank + 1) * world_size)`` of the variable's current value.

        Raises:
            RuntimeError: If ``get_device_indices`` raises ``IndexError`` for
                that range. The original error is chained.
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e
        os.environ[device_control_env_var] = value

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """

    def run(self):
        """
        Run the engine core busy loop.

        Errors are logged and re-raised; ``shutdown`` always runs afterwards.
        """
        try:
            self.run_busy_loop()  # type: ignore[attr-defined]
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()  # type: ignore[attr-defined]


class DPMoEEngineCoreActor(EngineCoreActorMixin, DPEngineCoreProc):
    """Ray actor hosting a data-parallel EngineCore for MoE models."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Record the global DP rank before either base initializer runs;
        # DPEngineCoreProc.__init__ reads it to pick its engine index.
        vllm_config.parallel_config.data_parallel_rank = dp_rank

        # Ray-specific setup first (tracing, addresses, visible devices)...
        EngineCoreActorMixin.__init__(
            self,
            vllm_config=vllm_config,
            addresses=addresses,
            dp_rank=dp_rank,
            local_dp_rank=local_dp_rank,
        )
        # ...then the engine. The handshake address is left empty because
        # the mixin's _perform_handshakes yields the pre-known addresses.
        DPEngineCoreProc.__init__(
            self,
            vllm_config=vllm_config,
            local_client=local_client,
            handshake_address="",
            executor_class=executor_class,
            log_stats=log_stats,
        )


class EngineCoreActor(EngineCoreActorMixin, EngineCoreProc):
    """Ray actor hosting a single EngineCore (non-MoE and/or non-DP cases)."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Present this engine to itself as a DP group of size 1 at rank 0.
        parallel_cfg = vllm_config.parallel_config
        parallel_cfg.data_parallel_size = 1
        parallel_cfg.data_parallel_size_local = 1
        parallel_cfg.data_parallel_rank = 0

        EngineCoreActorMixin.__init__(
            self,
            vllm_config=vllm_config,
            addresses=addresses,
            dp_rank=dp_rank,
            local_dp_rank=local_dp_rank,
        )
        # The caller-supplied dp_rank still identifies this engine.
        EngineCoreProc.__init__(
            self,
            vllm_config,
            local_client,
            "",
            executor_class,
            log_stats,
            engine_index=dp_rank,
        )