core.py 78.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import contextlib
4
import os
5
import queue
6
import signal
7
8
import threading
import time
9
from collections import defaultdict, deque
10
from collections.abc import Callable, Generator
11
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
12
from contextlib import ExitStack, contextmanager
13
from functools import partial
14
from inspect import isclass, signature
15
from logging import DEBUG
16
from typing import Any, TypeVar, cast
17

18
import msgspec
19
20
import zmq

21
import vllm.envs as envs
22
23
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
24
from vllm.envs import enable_envs_cache
25
from vllm.logger import init_logger
26
from vllm.logging_utils.dump_input import dump_engine_exception
27
from vllm.lora.request import LoRARequest
28
from vllm.multimodal import MULTIMODAL_REGISTRY
29
from vllm.tasks import POOLING_TASKS, SupportedTask
30
from vllm.tracing import instrument, maybe_init_worker_tracer
31
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
32
33
34
35
from vllm.utils.gc_utils import (
    freeze_gc_heap,
    maybe_attach_gc_debug_callback,
)
36
from vllm.utils.hashing import get_hash_fn_by_name
37
from vllm.utils.network_utils import make_zmq_socket
38
from vllm.utils.system_utils import decorate_logs, set_process_title
39
40
41
42
43
44
45
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
46
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
47
from vllm.v1.core.sched.output import SchedulerOutput
48
from vllm.v1.engine import (
49
50
    EEP_NOTIFICATION_CALL_ID,
    EEPNotificationType,
51
    EngineCoreOutput,
52
53
54
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
55
    FinishReason,
56
    PauseMode,
57
58
59
60
61
62
63
64
65
66
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
67
from vllm.v1.executor import Executor
68
from vllm.v1.kv_cache_interface import KVCacheConfig
69
from vllm.v1.metrics.stats import SchedulerStats
70
from vllm.v1.outputs import ModelRunnerOutput
71
from vllm.v1.request import Request, RequestStatus
72
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
73
from vllm.v1.structured_output import StructuredOutputManager
74
from vllm.v1.utils import compute_iteration_details
75
76
77
78
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Timeout, in minutes, for the engine startup handshake. It is not
# referenced by EngineCore itself (presumably used by the process wrapper).
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar("_R")  # Return type for collective_rpc

class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler and the structured-output
    manager, and advances them one iteration at a time via ``step`` (or
    ``step_with_batch_queue`` when the executor reports more than one
    concurrent batch).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
        include_finished_set: bool = False,
    ):
        """Create the executor, initialize the KV cache and the scheduler.

        Args:
            vllm_config: Engine configuration. ``cache_config.num_gpu_blocks``
                / ``num_cpu_blocks`` are written back after profiling, and
                ``scheduler_config.enable_chunked_prefill`` may be turned off
                for models without KV cache.
            executor_class: Executor implementation to instantiate.
            log_stats: Forwarded to the scheduler to enable stats collection.
            executor_fail_callback: Optional callback registered with the
                executor via ``register_failure_callback``.
            include_finished_set: Forwarded to the scheduler.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        if not vllm_config.parallel_config.data_parallel_rank_local:
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        # -1 means "not profiled yet"; set by _initialize_kv_caches or by the
        # elastic-EP scale-up path below.
        self.available_gpu_memory_for_kv_cache = -1

        if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
            self._eep_scale_up_before_kv_init()

        # Setup KV Caches and update CacheConfig after profiling.
        try:
            num_gpu_blocks, num_cpu_blocks, kv_cache_config = (
                self._initialize_kv_caches(vllm_config)
            )
        except Exception:
            logger.exception(
                "EngineCore failed during KV cache initialization; "
                "shutting down executor."
            )
            # Don't leak executor workers if initialization fails.
            self.model_executor.shutdown()
            raise

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        Scheduler = vllm_config.scheduler_config.get_scheduler_cls()

        if len(kv_cache_config.kv_cache_groups) == 0:  # noqa: SIM102
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            if vllm_config.scheduler_config.enable_chunked_prefill:
                logger.warning("Disabling chunked prefill for model without KVCache")
                vllm_config.scheduler_config.enable_chunked_prefill = False

        # The scheduler's block size is the cache block size scaled by the
        # decode- and prefill-context-parallel sizes (presumably because one
        # scheduler block is spread across the CP ranks — confirm).
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
            * vllm_config.parallel_config.prefill_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=include_finished_set,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None
        if self.scheduler.connector is not None:  # type: ignore
            self.model_executor.init_kv_output_aggregator(self.scheduler.connector)  # type: ignore

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = mm_registry.engine_receiver_cache_from_config(
            vllm_config
        )

        # If a KV connector is initialized for scheduler, we want to collect
        # handshake metadata from all workers so the connector in the scheduler
        # will have the full context
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            # Collect and store KV connector xfer metadata from workers
            # (after KV cache registration)
            xfer_handshake_metadata = (
                self.model_executor.get_kv_connector_handshake_metadata()
            )

            if xfer_handshake_metadata:
                # xfer_handshake_metadata is list of dicts from workers
                # Each dict already has structure {tp_rank: metadata}
                # Merge all worker dicts into a single dict
                content: dict[int, Any] = {}
                for worker_dict in xfer_handshake_metadata:
                    if worker_dict is not None:
                        content.update(worker_dict)
                kv_connector.set_xfer_handshake_metadata(content)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput, Future[Any]]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.debug("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        # Without an EC transfer config, this engine always acts as consumer.
        self.is_ec_consumer = (
            vllm_config.ec_transfer_config is None
            or vllm_config.ec_transfer_config.is_ec_consumer
        )
        self.is_pooling_model = vllm_config.model_config.runner_type == "pooling"

        # Block hashes are needed for prefix caching and by KV connectors.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )
        self.async_scheduling = vllm_config.scheduler_config.async_scheduling

        # Abort requests received while a step is running; drained by
        # _process_aborts_queue before outputs are processed.
        self.aborts_queue = queue.Queue[list[str]]()

        self._idle_state_callbacks: list[Callable] = []

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()
        # If enable, attach GC debugger after static variable freeze.
        maybe_attach_gc_debug_callback()
        # Enable environment variable cache (e.g. assume no more
        # environment variable overrides after this point)
        enable_envs_cache()

    @instrument(span_name="Prepare model")
    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Profile memory, build KV cache configs and initialize workers.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``.
            ``num_cpu_blocks`` is always 0 here.
        """
        # Monotonic clock: the measured duration must not be affected by
        # wall-clock adjustments (NTP, manual changes) during startup.
        start = time.monotonic()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
                # NOTE(yongji): should already be set
                # during _eep_scale_up_before_kv_init
                assert self.available_gpu_memory_for_kv_cache > 0
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        # Track max_model_len before KV cache config to detect auto-fit changes
        max_model_len_before = vllm_config.model_config.max_model_len

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )

        # If auto-fit reduced max_model_len, sync the new value to workers.
        # This is needed because workers were spawned before memory profiling
        # and have the original (larger) max_model_len cached.
        max_model_len_after = vllm_config.model_config.max_model_len
        if max_model_len_after != max_model_len_before:
            self.collective_rpc("update_max_model_len", args=(max_model_len_after,))

        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.monotonic() - start
        logger.info_once(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
            scope="local",
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the loaded model (from the executor)."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case. Unused by this base implementation.

        Raises:
            TypeError: if ``request.request_id`` is not a ``str``.
            ValueError: if a pooling request asks for an unsupported task.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        if request.kv_transfer_params is not None and (
            not self.scheduler.get_kv_connector()
        ):
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Dump engine state for ``scheduler_output`` if the body raises.

        The exception is always re-raised unchanged.
        """
        try:
            yield
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare raise re-raises the active exception with its original
            # traceback, without adding this frame.
            raise

    @contextmanager
    def log_iteration_details(self, scheduler_output: SchedulerOutput):
        """Log per-iteration batch composition and elapsed time, if enabled.

        This is a no-op unless
        ``observability_config.enable_logging_iteration_details`` is set. If
        the wrapped body raises, nothing is logged and the iteration counter
        is not advanced.
        """
        if not self.vllm_config.observability_config.enable_logging_iteration_details:
            yield
            return
        # Counter is created lazily on first use.
        self._iteration_index = getattr(self, "_iteration_index", 0)
        iteration_details = compute_iteration_details(scheduler_output)
        before = time.monotonic()
        yield
        elapsed_ms = (time.monotonic() - before) * 1000
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info(
            "Iteration(%s): %s context requests, %s context tokens, "
            "%s generation requests, %s generation tokens, "
            "iteration elapsed time: %.2f ms",
            self._iteration_index,
            iteration_details.num_ctx_requests,
            iteration_details.num_ctx_tokens,
            iteration_details.num_generation_requests,
            iteration_details.num_generation_tokens,
            elapsed_ms,
        )
        self._iteration_index += 1

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output, non_block=True)
        # Computed while the model runs asynchronously.
        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
            model_output = future.result()
            if model_output is None:
                # execute_model produced no output, so sampling is a
                # separate call.
                model_output = self.model_executor.sample_tokens(grammar_output)

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

    def post_step(self, model_executed: bool) -> None:
        """Post-step hook: feed draft token ids back to the scheduler.

        This only applies with speculative decoding, synchronous scheduling,
        and when the model actually ran.
        """
        # When using async scheduling we can't get draft token ids in advance,
        # so we update draft token ids in the worker process and don't
        # need to update draft token ids here.
        if not self.async_scheduling and self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        # Set when sampling for the newly scheduled batch must wait until the
        # previous batch's output has been processed.
        deferred_scheduler_output: SchedulerOutput | None = None
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            exec_future = self.model_executor.execute_model(
                scheduler_output, non_block=True
            )
            if self.is_ec_consumer:
                model_executed = scheduler_output.total_num_scheduled_tokens > 0

            if self.is_pooling_model or not model_executed:
                # No sampling required (no requests scheduled).
                future = cast(Future[ModelRunnerOutput], exec_future)
            else:
                if not scheduler_output.pending_structured_output_tokens:
                    # We aren't waiting for any tokens, get any grammar output
                    # and sample immediately.
                    grammar_output = self.scheduler.get_grammar_bitmask(
                        scheduler_output
                    )
                    future = self.model_executor.sample_tokens(
                        grammar_output, non_block=True
                    )
                else:
                    # We need to defer sampling until we have processed the model output
                    # from the prior step.
                    deferred_scheduler_output = scheduler_output

            if deferred_scheduler_output is None:
                # Add this step's future to the queue.
                batch_queue.appendleft((future, scheduler_output, exec_future))
                if (
                    model_executed
                    and len(batch_queue) < self.batch_queue_size
                    and not batch_queue[-1][0].done()
                ):
                    # Don't block on next worker response unless the queue is full
                    # or there are no more requests to schedule.
                    return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output, exec_model_fut = batch_queue.pop()
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
            model_output = future.result()
            if model_output is None:
                # None from sample_tokens() implies that the original execute_model()
                # call failed - raise that exception.
                exec_model_fut.result()
                raise RuntimeError("unexpected error")

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        # NOTE(nick): We can either handle the deferred tasks here or save
        # in a field and do it immediately once step_with_batch_queue is
        # re-called. The latter slightly favors TTFT over TPOT/throughput.
        if deferred_scheduler_output is not None:
            # If we are doing speculative decoding with structured output,
            # we need to get the draft token ids from the prior step before
            # we can compute the grammar bitmask for the deferred request.
            if self.use_spec_decode:
                draft_token_ids = self.model_executor.take_draft_token_ids()
                assert draft_token_ids is not None
                # Update the draft token ids in the scheduler output to
                # filter out the invalid spec tokens, which will be padded
                # with -1 and skipped by the grammar bitmask computation.
                self.scheduler.update_draft_token_ids_in_output(
                    draft_token_ids, deferred_scheduler_output
                )
            # We now have the tokens needed to compute the bitmask for the
            # deferred request. Get the bitmask and call sample tokens.
            grammar_output = self.scheduler.get_grammar_bitmask(
                deferred_scheduler_output
            )
            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
            batch_queue.appendleft((future, deferred_scheduler_output, exec_future))

        return engine_core_outputs, model_executed

    def _process_aborts_queue(self):
        """Drain ``aborts_queue`` and abort everything in it as one batch."""
        if not self.aborts_queue.empty():
            request_ids = []
            while not self.aborts_queue.empty():
                ids = self.aborts_queue.get_nowait()
                # Should be a list here, but also handle string just in case.
                request_ids.extend((ids,) if isinstance(ids, str) else ids)
            # More efficient to abort all as a single batch.
            self.abort_requests(request_ids)

    def shutdown(self):
        """Release the structured-output backend, executor and scheduler."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True, profile_prefix: str | None = None):
        """Start (``is_start=True``) or stop profiling on the executor."""
        self.model_executor.profile(is_start, profile_prefix)

    def reset_mm_cache(self):
        """Clear the multi-modal caches in the engine and in the executor."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the scheduler's prefix cache and return its result."""
        return self.scheduler.reset_prefix_cache(
            reset_running_requests, reset_connector
        )

    def reset_encoder_cache(self) -> None:
        """Reset the encoder cache to invalidate all cached encoder outputs.

        This should be called when model weights are updated to ensure
        stale vision embeddings computed with old weights are not reused.
        Clears both the scheduler's cache manager and the GPU model runner's cache.
        """
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the encoder cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # Reset the scheduler's encoder cache manager (logical state)
        self.scheduler.reset_encoder_cache()
        # Reset the GPU model runner's encoder cache (physical storage)
        self.model_executor.reset_encoder_cache()

    def _reset_caches(self, reset_running_requests=True) -> None:
        """Reset prefix, multi-modal and encoder caches."""
        self.reset_prefix_cache(reset_running_requests=reset_running_requests)
        self.reset_mm_cache()
        self.reset_encoder_cache()

    def pause_scheduler(
        self, mode: PauseMode = "abort", clear_cache: bool = True
    ) -> Future | None:
        """Pause generation; behavior depends on mode.

        All pause modes queue new adds -- "abort" and "keep" skip step();
        "wait" allows step() so in-flight requests can drain.

        - ``abort``: Set PAUSED_NEW, abort all requests, wait for abort
          outputs to be sent (when running with output_queue), optionally
          clear caches, then complete the returned Future.
        - ``wait``: Set PAUSED_NEW (queue adds, keep stepping); when drained,
          optionally clear caches, then complete the returned Future.
        - ``keep``: Set PAUSED_ALL; return a Future that completes when the
          output queue is empty.

        This in-process implementation completes synchronously: it rejects
        ``"wait"`` with ``ValueError`` and always returns ``None``.

        Raises:
            ValueError: for an unknown mode, or for ``"wait"``.
        """
        if mode not in ("keep", "abort", "wait"):
            raise ValueError(f"Invalid pause mode: {mode}")
        if mode == "wait":
            raise ValueError("'wait' mode can't be used in inproc-engine mode")

        if mode == "abort":
            self.scheduler.finish_requests(None, RequestStatus.FINISHED_ABORTED)

        pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
        self.scheduler.set_pause_state(pause_state)
        if clear_cache:
            self._reset_caches()

        return None

    def resume_scheduler(self) -> None:
        """Resume the scheduler and flush any requests queued while paused."""
        self.scheduler.set_pause_state(PauseState.UNPAUSED)

    def is_scheduler_paused(self) -> bool:
        """Return whether the scheduler is in any pause state."""
        return self.scheduler.pause_state != PauseState.UNPAUSED

    def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None | Future:
        """Put the engine to sleep at the specified level.

        Args:
            level: Sleep level.
                - Level 0: Pause scheduling only. Requests are still accepted
                           but not processed. No GPU memory changes.
                - Level 1: Offload model weights to CPU, discard KV cache.
                - Level 2: Discard all GPU memory.
            mode: Pause mode - how to deal with any existing requests, see
                documentation of pause_scheduler method.

        Returns:
            ``None`` if sleeping completed synchronously, otherwise a Future
            that resolves once the pause (and, for level >= 1, the executor
            sleep) has finished.
        """

        # Pause scheduler before sleeping.
        clear_prefix_cache = level >= 1
        pause_future = self.pause_scheduler(mode=mode, clear_cache=clear_prefix_cache)
        if level < 1:
            return pause_future

        # Level 1+: Delegate to executor for GPU memory management
        model_executor = self.model_executor
        if pause_future is None:
            model_executor.sleep(level)
            return None

        future = Future[Any]()

        def pause_complete(f: Future):
            # Chain: only sleep the executor once the pause has completed.
            try:
                f.result()  # propagate any exception
                future.set_result(model_executor.sleep(level))
            except Exception as e:
                future.set_exception(e)

        logger.info("Waiting for in-flight requests to complete before sleeping...")
        pause_future.add_done_callback(pause_complete)
        return future

    def wake_up(self, tags: list[str] | None = None):
        """Wake up the engine from sleep.

        Args:
            tags: Tags to wake up. Use ["scheduling"] for level 0 wake up.
        """
        if tags is not None and "scheduling" in tags:
            # Remove "scheduling" from tags if there are other tags to process.
            tags = [t for t in tags if t != "scheduling"]

        # None means "wake everything"; an empty list means only scheduling
        # was requested, so the executor is left alone.
        if tags is None or tags:
            self.model_executor.wake_up(tags)

        # Resume scheduling (applies to all levels)
        self.resume_scheduler()

    def is_sleeping(self) -> bool:
        """Check if engine is sleeping at any level."""
        return self.is_scheduler_paused() or self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run a dummy batch on the executor."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter via the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter via the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the ids of LoRA adapters known to the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter via the executor."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Delegate sharded-state saving to the executor."""
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Invoke ``method`` on all workers through the executor."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The internal ``Request`` and the client's ``current_wave``.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave

    def _eep_scale_up_before_kv_init(self):
        """Elastic-EP scale-up hook; not supported by the base engine."""
        raise NotImplementedError

    def _eep_send_engine_core_notification(
        self,
        notification_type: EEPNotificationType,
        vllm_config: VllmConfig | None = None,
    ):
        """Elastic-EP notification hook; not supported by the base engine."""
        raise NotImplementedError

778
779
780
781

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
    addresses: EngineZmqAddresses

    @instrument(span_name="EngineCoreProc init")
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        *,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, init the core, and start IO threads.

        The base ``EngineCore`` is initialized inside the handshake context so
        the READY (or FAILED) status is reported only after initialization
        finishes (or fails).
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()
        # Surfaces executor failure on the busy loop thread via the input queue.
        executor_fail_callback = lambda: self.input_queue.put_nowait(
            (EngineCoreRequestType.EXECUTOR_FAILED, b"")
        )

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        ) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            internal_dp_balancing = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )
            # Only publish request queue stats to coordinator for "internal"
            # and "hybrid" LB modes.
            self.publish_dp_lb_stats = internal_dp_balancing

            self.addresses = addresses
            self.process_input_queue_block = True
            if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
                self._eep_send_engine_core_notification(
                    EEPNotificationType.NEW_CORE_ENGINES_INIT_READY,
                    vllm_config=vllm_config,
                )
            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config,
                executor_class,
                log_stats,
                executor_fail_callback,
                internal_dp_balancing,
            )

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    identity,
                    ready_event,
                ),
                daemon=True,
            )
            input_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(
                    addresses.outputs,
                    addresses.coordinator_output,
                    self.engine_index,
                ),
                daemon=True,
            )
            self.output_thread.start()

            # Don't complete handshake until DP coordinator ready message is
            # received.
            while not ready_event.wait(timeout=10):
                if not input_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.
        """
        # The context is terminated on exit (previously it was never closed
        # and leaked until garbage collection). Every handshake socket is
        # closed by its own inner ``with`` before the context exits.
        with zmq.Context() as input_ctx:
            is_local = local_client and client_handshake_address is None
            headless = not local_client
            handshake = self._perform_handshake(
                input_ctx,
                handshake_address,
                identity,
                is_local,
                headless,
                vllm_config,
                vllm_config.parallel_config,
            )
            if client_handshake_address is None:
                with handshake as addresses:
                    yield addresses
            else:
                assert local_client
                local_handshake = self._perform_handshake(
                    input_ctx,
                    client_handshake_address,
                    identity,
                    True,
                    False,
                    vllm_config,
                )
                with handshake as addresses, local_handshake as client_addresses:
                    addresses.inputs = client_addresses.inputs
                    addresses.outputs = client_addresses.outputs
                    yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run one startup handshake and yield the addresses it returns.

        Registers with the front-end via ``startup_handshake``. If the
        ``with`` body completes normally, a READY status is sent back.
        If it raises anything -- including ``SystemExit`` from the
        termination signal handler -- a best-effort FAILED status is sent
        and the exception is re-raised.
        """
        with make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        ) as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(
                handshake_socket, local_client, headless, parallel_config_to_update
            )
            try:
                yield addresses
            except BaseException:
                # BaseException (not just Exception) so that SystemExit /
                # KeyboardInterrupt during init is reported as FAILED rather
                # than falling through to the READY path below.
                # Send FAILED status so the front-end detects init
                # failure immediately via ZMQ instead of waiting for
                # process sentinel (which may be delayed by cleanup).
                with contextlib.suppress(Exception):
                    handshake_socket.send(
                        msgspec.msgpack.encode(
                            {
                                "status": "FAILED",
                                "local": local_client,
                                "headless": headless,
                            }
                        )
                    )
                raise

            # Send ready message (only reached when the body succeeded).
            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
            # We pass back the coordinator stats update address
            # here for the external LB case for our colocated
            # front-end to use (coordinator only runs with rank 0).
            dp_stats_address = self.frontend_stats_publish_address

            # Include config hash for DP configuration validation
            ready_msg = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": num_gpu_blocks,
                "dp_stats_address": dp_stats_address,
            }
            if vllm_config.parallel_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = (
                    vllm_config.parallel_config.compute_hash()
                )

            handshake_socket.send(msgspec.msgpack.encode(ready_msg))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Send HELLO, await the front-end's init message, return its addresses.

        Any ``parallel_config`` overrides carried in the init message are
        applied to ``parallel_config`` (when given) via ``setattr``.

        Raises:
            RuntimeError: if no init message arrives within
                ``HANDSHAKE_TIMEOUT_MINS`` minutes.
        """
        # Send registration message.
        handshake_socket.send(
            msgspec.msgpack.encode(
                {
                    "status": "HELLO",
                    "local": local_client,
                    "headless": headless,
                }
            )
        )

        # Receive initialization message.
        logger.debug("Waiting for init message from front-end.")
        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                f"minutes"
            )
        init_bytes = handshake_socket.recv()
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            init_bytes, type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for key, value in init_message.parallel_config.items():
                setattr(parallel_config, key, value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process."""

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: EngineCoreProc | None = None
        try:
            vllm_config: VllmConfig = kwargs["vllm_config"]
            parallel_config: ParallelConfig = vllm_config.parallel_config
            data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
            if data_parallel:
                parallel_config.data_parallel_rank_local = local_dp_rank
                maybe_init_worker_tracer(
                    instrumenting_module_name="vllm.engine_core",
                    process_kind="engine_core",
                    process_name=f"EngineCore_DP{dp_rank}",
                )
                set_process_title("EngineCore", f"DP{dp_rank}")
            else:
                maybe_init_worker_tracer(
                    instrumenting_module_name="vllm.engine_core",
                    process_kind="engine_core",
                    process_name="EngineCore",
                )
                set_process_title("EngineCore")
            decorate_logs()

            if data_parallel and vllm_config.kv_transfer_config is not None:
                # modify the engine_id and append the local_dp_rank to it to ensure
                # that the kv_transfer_config is unique for each DP rank.
                vllm_config.kv_transfer_config.engine_id = (
                    f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
                )
                logger.debug(
                    "Setting kv_transfer_config.engine_id to %s",
                    vllm_config.kv_transfer_config.engine_id,
                )

            parallel_config.data_parallel_index = dp_rank
            if data_parallel and vllm_config.model_config.is_moe:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                # Non-MoE DP ranks are completely independent, so treat like DP=1.
                # Note that parallel_config.data_parallel_index will still reflect
                # the original DP rank.
                parallel_config.data_parallel_size = 1
                parallel_config.data_parallel_size_local = 1
                parallel_config.data_parallel_rank = 0
                engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)

            assert engine_core is not None
            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare raise re-raises the active exception unchanged.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Data-parallel setup hook; a no-op here, overridden for DP engines."""
        pass

    def has_work(self) -> bool:
        """Returns true if the engine should be stepped."""
        return (
            self.engines_running
            or self.scheduler.has_requests()
            or bool(self.batch_queue)
        )

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed."""

        waited = False
        while not self.has_work():
            # Notify callbacks waiting for engine to become idle.
            self._notify_idle_state_callbacks()
            if self.input_queue.empty():
                # Drain aborts queue; all aborts are also processed via input_queue.
                with self.aborts_queue.mutex:
                    self.aborts_queue.queue.clear()
                if logger.isEnabledFor(DEBUG):
                    logger.debug("EngineCore waiting for work.")
                    waited = True
            block = self.process_input_queue_block
            try:
                req = self.input_queue.get(block=block)
                self._handle_client_request(*req)
            except queue.Empty:
                break
            if not block:
                break

        if waited:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self) -> bool:
        """Called only when there are unfinished local requests.

        Returns:
            Whether the model was executed during this step.
        """

        # Step the engine core.
        outputs, model_executed = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        for output in outputs.items() if outputs else ():
            self.output_queue.put_nowait(output)
        # Post-step hook.
        self.post_step(model_executed)

        # If no model execution happened but there are waiting requests
        # (e.g., WAITING_FOR_REMOTE_KVS), yield the GIL briefly to allow
        # background threads (like NIXL handshake) to make progress.
        # Without this, the tight polling loop can starve background threads.
        if not model_executed and self.scheduler.has_unfinished_requests():
            time.sleep(0.001)

        return model_executed

    def _notify_idle_state_callbacks(self) -> None:
        """Pop and invoke every registered idle-state callback with ``self``."""
        while self._idle_state_callbacks:
            callback = self._idle_state_callbacks.pop()
            callback(self)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch request from client."""

        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            # Lazily look-up utility method so that failure will be handled/returned.
            get_result = lambda: (method := getattr(self, method_name)) and method(
                *self._convert_msgspec_args(method, args)
            )
            enqueue_output = lambda out: self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=out))
            )
            self._invoke_utility_method(method_name, get_result, output, enqueue_output)
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )

    @staticmethod
    def _invoke_utility_method(
        name: str, get_result: Callable, output: UtilityOutput, enqueue_output: Callable
    ):
        """Run a utility call and enqueue its ``UtilityOutput``.

        If ``get_result`` returns a ``Future``, output is deferred: this method
        is re-invoked from the future's done-callback with ``future.result``.
        Any exception is logged and recorded in ``output.failure_message``
        instead of propagating.
        """
        try:
            result = get_result()
            if isinstance(result, Future):
                # Defer utility output handling until future completion.
                callback = lambda future: EngineCoreProc._invoke_utility_method(
                    name, future.result, output, enqueue_output
                )
                result.add_done_callback(callback)
                return
            output.result = UtilityResult(result)
        except Exception as e:
            logger.exception("Invocation of %s method failed", name)
            output.failure_message = f"Call to {name} method failed: {str(e)}"
        enqueue_output(output)

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
        arg type, try converting to msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, arg_types)
        )

    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient."""

        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until msg sent by the daemon before shutdown.
        self.output_thread.join(timeout=5.0)
        if self.output_thread.is_alive():
            logger.fatal(
                "vLLM shutdown signal from EngineCore failed "
                "to send. Please report this issue."
            )

    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread."""

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. The blocking recv
                # must not live inside an ``assert``: under ``python -O`` it
                # would be stripped and ready_event would be set too early.
                coord_ready_msg = coord_socket.recv()
                if coord_ready_msg != b"READY":
                    raise RuntimeError(
                        "Expected READY message from DP Coordinator, "
                        f"got {coord_ready_msg!r}"
                    )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    # NOTE(yongji): ignore READY message sent by DP coordinator
                    # that is used to notify newly started engines
                    if type_frame.buffer == b"READY":
                        assert input_socket == coord_socket
                        continue
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    request: Any
                    if request_type == EngineCoreRequestType.ADD:
                        req: EngineCoreRequest = add_request_decoder.decode(data_frames)
                        try:
                            request = self.preprocess_add_request(req)
                        except Exception:
                            self._handle_request_preproc_error(req)
                            continue
                    else:
                        request = generic_decoder.decode(data_frames)

                        if request_type == EngineCoreRequestType.ABORT:
                            # Aborts are added to *both* queues, allows us to eagerly
                            # process aborts while also ensuring ordering in the input
                            # queue to avoid leaking requests. This is ok because
                            # aborting in the scheduler is idempotent.
                            self.aborts_queue.put_nowait(request)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread."""

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                )
                for output_path in output_paths
            ]
            coord_socket = (
                stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
                if coord_output_path is not None
                else None
            )
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with.
                while pending and pending[-1][0].done:
                    reuse_buffers.append(pending.pop()[2])

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(
                    buffers, copy=False, track=True
                )
                if not tracker.done:
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)

    def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None:
        """Log and return a request-scoped error response for exceptions raised
        from the add request preprocessing in the input socket processing thread.
        """
        logger.exception(
            "Unexpected error pre-processing request %s", request.request_id
        )
        self.output_queue.put_nowait(
            (
                request.client_index,
                EngineCoreOutputs(
                    engine_index=self.engine_index,
                    finished_requests={request.request_id},
                    outputs=[
                        EngineCoreOutput(
                            request_id=request.request_id,
                            new_token_ids=[],
                            finish_reason=FinishReason.ERROR,
                        )
                    ],
                ),
            )
        )

    def pause_scheduler(
        self, mode: PauseMode = "abort", clear_cache: bool = True
    ) -> Future | None:
        """Pause generation; behavior depends on mode.

        All pause modes queue new adds -- "abort" and "keep" skip step();
        "wait" allows step() so in-flight requests can drain.

        - ``abort``: Set PAUSED_NEW, abort all requests, wait for abort
          outputs to be sent (when running with output_queue), optionally
          clear caches, then complete the returned Future.
        - ``wait``: Set PAUSED_NEW (queue adds, keep stepping); when drained,
          optionally clear caches, then complete the returned Future.
        - ``keep``: Set PAUSED_ALL; return a Future that completes when the
          output queue is empty.

        If the engine already has no work after the pause state is set, caches
        are cleared immediately (when ``clear_cache``) and ``None`` is returned
        instead of a Future.

        Raises:
            ValueError: if ``mode`` is not one of "keep", "abort", "wait".
        """
        if mode not in ("keep", "abort", "wait"):
            raise ValueError(f"Invalid pause mode: {mode}")

        def engine_idle_callback(engine: "EngineCoreProc", future: Future[Any]) -> None:
            if clear_cache:
                engine._reset_caches()
            future.set_result(None)

        if mode == "abort":
            aborted_reqs = self.scheduler.finish_requests(
                None, RequestStatus.FINISHED_ABORTED
            )
            self._send_abort_outputs(aborted_reqs)

        pause_state = PauseState.PAUSED_ALL if mode == "keep" else PauseState.PAUSED_NEW
        self.scheduler.set_pause_state(pause_state)
        if not self.has_work():
            if clear_cache:
                self._reset_caches()
            return None

        future = Future[Any]()
        self._idle_state_callbacks.append(partial(engine_idle_callback, future=future))
        return future

    def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None:
        """Enqueue ABORT-finished outputs for ``(request_id, client_index)`` pairs.

        One ``EngineCoreOutputs`` is queued per distinct client index.
        """
        # TODO(nick) this will be moved inside the scheduler
        if aborted_reqs:
            # Map client_index to list of request_ids that belong to that client.
            by_client = defaultdict[int, set[str]](set)
            for req_id, client_index in aborted_reqs:
                by_client[client_index].add(req_id)
            for client_index, req_ids in by_client.items():
                outputs = [
                    EngineCoreOutput(req_id, [], finish_reason=FinishReason.ABORT)
                    for req_id in req_ids
                ]
                eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs)
                self.output_queue.put_nowait((client_index, eco))

1529
1530
1531
1532
1533
1534
1535
1536

class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    # Number of engine steps between finish-sync all-reduces with DP peers
    # (see _has_global_unfinished_reqs). Between syncs the wave is assumed to
    # still be running. The all-reduce is a collective over dp_group, so this
    # value must be identical on every DP rank.
    _DP_FINISH_SYNC_INTERVAL: int = 32

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        """Create a DP engine core; only valid for MoE models.

        The DP-specific state is initialized before ``super().__init__()``
        because the base initializer may call back into overridden methods
        of this class (presumably e.g. ``_init_data_parallel``; confirm
        against ``EngineCoreProc.__init__``).
        """
        assert vllm_config.model_config.is_moe, (
            "DPEngineCoreProc should only be used for MoE models"
        )

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        self.current_wave = 0
        # Last request counts published by _maybe_publish_request_counts;
        # used to skip publishing when nothing changed.
        self.last_counts = (0, 0)

        from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState

        # In-progress Elastic EP scale up/down, or None when not scaling.
        self.eep_scaling_state: ElasticEPScalingState | None = None

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            engine_index=dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Validate DP ranks and create the stateless DP process group."""
        # Configure GPUs and stateless process group for data parallel.
        parallel_config = vllm_config.parallel_config
        dp_rank = parallel_config.data_parallel_rank
        dp_size = parallel_config.data_parallel_size
        local_dp_rank = parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert local_dp_rank is not None
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        self.dp_rank = dp_rank
        dp_group, dp_store = parallel_config.stateless_init_dp_group(return_store=True)
        self.dp_group, self.dp_store = dp_group, dp_store

    def shutdown(self):
        """Shut down the engine core, then tear down the DP process group."""
        super().shutdown()
        # dp_group may not exist if initialization failed before it was set.
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request and reconcile the DP wave it was submitted in.

        With a coordinator, a request from a newer wave advances
        ``current_wave``. A request for an already-completed wave, arriving
        while the engines are paused, makes this engine ask the front-end to
        start the next wave.
        """
        super().add_request(request, request_wave)
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

    def resume_scheduler(self):
        """Resume scheduling and wake peer engines if we have queued work."""
        super().resume_scheduler()
        if (
            self.has_coordinator
            and not self.engines_running
            and self.scheduler.has_unfinished_requests()
        ):
            # Wake up other DP engines.
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(start_wave=self.current_wave))
            )

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle START_DP_WAVE here; delegate all other requests to the base.

        A START_DP_WAVE payload is ``(new_wave, exclude_eng_index)``. The
        engine whose index equals ``exclude_eng_index`` ignores it, as does
        any engine already past ``new_wave``.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Publish scheduler request counts if they changed since last time.

        No-op unless ``publish_dp_lb_stats`` is set. Stats are sent with
        client index -1, i.e. to the coordinator.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case.

        Each iteration does the following:
        1) Drain client input.
        2) Step the engine. While a wave is running and no real work
           executed, run a dummy batch instead.
        3) Periodically agree with DP peers on whether any rank still has
           unfinished requests. When none do, the wave ends.
        """

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # Advance any in-progress Elastic EP reconfiguration. Blocking
            # input polling (disabled when scaling started) is restored once
            # it completes.
            if self.eep_scaling_state is not None:
                self.eep_scaling_state.progress()
                if self.eep_scaling_state.is_complete():
                    self.process_input_queue_block = True
                    self.eep_scaling_state = None

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank may still have unfinished requests.

        Optimization: the finish-sync all-reduce over ``dp_group`` only runs
        every ``_DP_FINISH_SYNC_INTERVAL`` steps. On other steps this returns
        True, meaning the wave is assumed to still be running. This relies on
        every rank calling it on the same steps. ``step_counter`` is reset
        to 0 by ``run_busy_loop`` when a wave ends.
        """
        self.step_counter += 1
        if self.step_counter % self._DP_FINISH_SYNC_INTERVAL != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Start an Elastic EP scale up/down described by ``reconfig_request``.

        A deep copy of the current ParallelConfig is updated with the
        requested DP size, rank (unless KEEP_CURRENT_RANK), master IP and
        master port(s). It is used to create an ElasticEPScalingState, which
        ``run_busy_loop`` advances every iteration. Blocking input polling
        stays disabled until that state reports completion.
        """
        from copy import deepcopy

        from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState

        new_parallel_config = deepcopy(self.vllm_config.parallel_config)
        old_dp_size = new_parallel_config.data_parallel_size
        new_parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            new_parallel_config.data_parallel_rank = (
                reconfig_request.new_data_parallel_rank
            )
        new_parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        new_parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        new_parallel_config._data_parallel_master_port_list = (
            reconfig_request.new_data_parallel_master_port_list
        )

        is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size
        # This rank is being removed from the new setup.
        is_shutdown = (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        )

        self.eep_scaling_state = ElasticEPScalingState(
            model_executor=self.model_executor,
            engine_core=self,
            vllm_config=self.vllm_config,
            new_parallel_config=new_parallel_config,
            worker_type="removing" if is_shutdown else "existing",
            scale_type="scale_down" if is_scale_down else "scale_up",
            reconfig_request=reconfig_request,
        )
        self.process_input_queue_block = False
        logger.info(
            "[Elastic EP] Received reconfiguration request and starting scaling up/down"
        )

    def _eep_send_engine_core_notification(
        self,
        notification_type: EEPNotificationType,
        vllm_config: VllmConfig | None = None,
    ):
        """
        Send notifications to EngineCoreClient, which can then forward
        the notifications to other engine core processes. It is used for:
        1) In scale up: new core engines to notify existing core engines
           that they are ready;
        2) In scale down: removing core engines to notify EngineCoreClient
           so EngineCoreClient can release their ray placement groups;
        3) Both scale up/down: to notify EngineCoreClient that existing
           core engines have already switched to the new parallel setup.

        The payload is ``(notification_type.value, dp_rank)``. The DP rank
        is taken from ``vllm_config`` if given, else from this engine's
        config.
        """
        if vllm_config is None:
            dp_rank = self.vllm_config.parallel_config.data_parallel_rank
        else:
            dp_rank = vllm_config.parallel_config.data_parallel_rank
        notification_data = (notification_type.value, dp_rank)
        outputs = EngineCoreOutputs(
            utility_output=UtilityOutput(
                call_id=EEP_NOTIFICATION_CALL_ID,
                result=UtilityResult(notification_data),
            )
        )
        outputs.engine_index = self.engine_index

        if hasattr(self, "output_thread") and self.output_thread.is_alive():
            self.output_queue.put_nowait((0, outputs))
        else:
            # No output thread yet (or it has stopped), so push the message
            # directly on a short-lived PUSH socket to the first output
            # address. linger=4000 lets pending data flush on close.
            encoder = MsgpackEncoder()
            with (
                zmq.Context() as ctx,
                make_zmq_socket(
                    ctx, self.addresses.outputs[0], zmq.PUSH, linger=4000
                ) as socket,
            ):
                socket.send_multipart(encoder.encode(outputs))

    def eep_handle_engine_core_notification(
        self, notification_type: str | EEPNotificationType
    ):
        """
        Handle notification received from EngineCoreClient
        (forwarded from new core engines).

        Requires an in-progress scaling operation. String values are
        converted to ``EEPNotificationType`` before dispatch.
        """
        assert self.eep_scaling_state is not None
        if isinstance(notification_type, str):
            notification_type = EEPNotificationType(notification_type)
        self.eep_scaling_state.handle_notification(notification_type)

    def _eep_scale_up_before_kv_init(self):
        """Bring up a newly added engine during scale-up, before KV-cache init.

        Steps, in order:
        1) Install a "new"-worker scaling state.
        2) Init the device and load the model.
        3) Announce that weights init is ready.
        4) Receive weights from peers.
        5) Sync available KV-cache memory with the DP group.
        6) Prepare the new worker.
        7) Disable blocking input polling until scaling completes.
        """
        from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState

        self.eep_scaling_state = ElasticEPScalingState(
            model_executor=self.model_executor,
            engine_core=self,
            vllm_config=self.vllm_config,
            new_parallel_config=self.vllm_config.parallel_config,
            worker_type="new",
            scale_type="scale_up",
            reconfig_request=None,
        )
        self.model_executor.collective_rpc("init_device")
        self.model_executor.collective_rpc("load_model")
        self._eep_send_engine_core_notification(
            EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
        )
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("receive_weights",)
        )
        # NOTE(review): -1 presumably means this new rank contributes no
        # measured value and adopts the group's; confirm in ParallelConfig.
        self.available_gpu_memory_for_kv_cache = (
            ParallelConfig.sync_kv_cache_memory_size(self.dp_group, -1)
        )
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("prepare_new_worker",)
        )
        self.process_input_queue_block = False
1827

Rui Qiao's avatar
Rui Qiao committed
1828

1829
class EngineCoreActorMixin:
    """
    Ray actor for running EngineCore in a data parallel context

    Mixed in ahead of an engine core process class. Subclasses call
    ``EngineCoreActorMixin.__init__`` explicitly before the engine core's own
    ``__init__``; this initializer does not call ``super().__init__()``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        addresses: EngineZmqAddresses,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Initialize tracer for distributed tracing if configured.
        maybe_init_worker_tracer(
            instrumenting_module_name="vllm.engine_core",
            process_kind="engine_core",
            process_name=f"DPEngineCoreActor_DP{dp_rank}",
        )

        # Addresses are known up front for Ray; see _perform_handshakes.
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_index = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict this actor's visible devices to its local DP slice.

        XPU is skipped. Presumably device visibility is handled elsewhere
        for XPU (TODO confirm). Other platforms set their
        ``device_control_env_var``.
        """
        from vllm.platforms import current_platform

        if current_platform.is_xpu():
            return

        self._set_cuda_visible_devices(
            vllm_config, local_dp_rank, current_platform.device_control_env_var
        )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Set ``device_control_env_var`` to this rank's device indices.

        The selected local range is
        ``[local_dp_rank * world_size, (local_dp_rank + 1) * world_size)``.

        Raises:
            RuntimeError: if the range does not fit the currently visible
                devices (the underlying ``IndexError`` is chained as the
                cause).
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e
        os.environ[device_control_env_var] = value

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        Errors are logged and re-raised. ``shutdown()`` always runs on the
        way out.
        """
        try:
            self.run_busy_loop()  # type: ignore[attr-defined]
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()  # type: ignore[attr-defined]


class DPMoEEngineCoreActor(EngineCoreActorMixin, DPEngineCoreProc):
    """Ray actor engine core for MoE models running with data parallelism.

    The Ray-specific mixin state (tracer, addresses, visible devices) is set
    up first, followed by the data-parallel engine core itself.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Record the global DP rank on the config first: DPEngineCoreProc
        # reads it to derive this engine's index.
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_rank = dp_rank

        EngineCoreActorMixin.__init__(
            self,
            vllm_config,
            addresses,
            dp_rank=dp_rank,
            local_dp_rank=local_dp_rank,
        )
        # Ray actors skip the ZMQ handshake, so the address is left empty.
        DPEngineCoreProc.__init__(
            self,
            vllm_config,
            local_client,
            handshake_address="",
            executor_class=executor_class,
            log_stats=log_stats,
        )


class EngineCoreActor(EngineCoreActorMixin, EngineCoreProc):
    """Ray actor engine core for non-MoE models and/or runs without DP."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # To the engine, this actor is a lone DP rank. dp_rank is still
        # forwarded as engine_index so separate actors stay distinguishable.
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_size = 1
        parallel_config.data_parallel_size_local = 1
        parallel_config.data_parallel_rank = 0

        EngineCoreActorMixin.__init__(
            self,
            vllm_config,
            addresses,
            dp_rank=dp_rank,
            local_dp_rank=local_dp_rank,
        )
        # Empty handshake address: Ray actors do not perform the handshake.
        EngineCoreProc.__init__(
            self,
            vllm_config,
            local_client,
            "",
            executor_class,
            log_stats,
            engine_index=dp_rank,
        )