core.py 52.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import gc
4
import os
5
import queue
6
import signal
7
8
import threading
import time
9
from collections import deque
10
from collections.abc import Callable, Generator
11
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
12
from contextlib import ExitStack, contextmanager
13
from inspect import isclass, signature
14
from logging import DEBUG
15
from typing import Any, TypeVar
16

17
import msgspec
18
19
import zmq

20
21
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
22
from vllm.distributed.parallel_state import is_global_first_rank
23
from vllm.envs import enable_envs_cache
24
from vllm.logger import init_logger
25
from vllm.logging_utils.dump_input import dump_engine_exception
26
from vllm.lora.request import LoRARequest
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.cache import engine_receiver_cache_from_config
29
from vllm.tasks import POOLING_TASKS, SupportedTask
30
31
32
33
34
35
36
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils import (
    decorate_logs,
    get_hash_fn_by_name,
    make_zmq_socket,
    set_process_title,
)
37
from vllm.utils.gc_utils import maybe_attach_gc_debug_callback
38
from vllm.utils.import_utils import resolve_obj_by_qualname
39
40
41
42
43
44
45
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
46
from vllm.v1.core.sched.interface import SchedulerInterface
47
48
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from vllm.v1.engine import (
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
63
from vllm.v1.executor.abstract import Executor
64
from vllm.v1.kv_cache_interface import KVCacheConfig
65
from vllm.v1.metrics.stats import SchedulerStats
66
from vllm.v1.outputs import ModelRunnerOutput
67
from vllm.v1.request import Request, RequestStatus
68
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
69
from vllm.v1.structured_output import StructuredOutputManager
70
71
72
73
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Polling timeout, in seconds (per the ``_S`` suffix).
POLLING_TIMEOUT_S: float = 2.5
# Minutes EngineCoreProc.startup_handshake() waits for the front-end's init
# message before raising.
HANDSHAKE_TIMEOUT_MINS: int = 5

# Element type of the list returned by collective_rpc().
_R = TypeVar("_R")


class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the KV-cache setup, the scheduler and the
    structured-output manager, and exposes ``step``/``step_with_batch_queue``
    (selected once in ``__init__`` as ``self.step_fn``) plus a set of thin
    delegations to the executor and scheduler.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
    ):
        """Build the executor, KV caches and scheduler for ``vllm_config``.

        Args:
            vllm_config: Engine configuration. ``cache_config.num_gpu_blocks``
                / ``num_cpu_blocks`` are written back after KV-cache sizing,
                and ``scheduler_config.chunked_prefill_enabled`` is cleared for
                models without KV-cache groups.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Forwarded to the scheduler.
            executor_fail_callback: If given, registered on the executor via
                ``register_failure_callback``.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        if is_global_first_rank():
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        # Overwritten by _initialize_kv_caches when the model has a KV cache.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler: scheduler_cls is either a class or its qualname.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls
            )
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls,
            )

        if len(kv_cache_config.kv_cache_groups) == 0:
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.chunked_prefill_enabled = False

        # The scheduler (and the request block hasher below) work in blocks
        # that span all decode-context-parallel ranks.
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None
        if self.scheduler.connector is not None:  # type: ignore
            self.model_executor.init_kv_output_aggregator(
                self.scheduler.connector.get_finished_count()  # type: ignore
            )

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = engine_receiver_cache_from_config(
            vllm_config, mm_registry
        )

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        # Block hashes are only needed for prefix caching or a KV connector.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if (
            self.vllm_config.cache_config.enable_prefix_caching
            or self.scheduler.get_kv_connector() is not None
        ):
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )

    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Size the KV cache, initialize it on the executor and warm up.

        Sets ``self.available_gpu_memory_for_kv_cache`` when the model has any
        KV-cache spec.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``,
            where ``num_cpu_blocks`` is always 0 here.
        """
        # perf_counter is monotonic, so the reported duration cannot be skewed
        # by system clock adjustments during this (long) phase.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Elastic EP scale-up: take the KV-cache memory size from the
                # DP group instead of profiling locally.
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )
        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the model executor."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case (not used by this base implementation).

        Raises:
            TypeError: If ``request.request_id`` is not a ``str``.
            ValueError: If the request has pooling params whose task is not
                among the supported pooling tasks.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r}. "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        if request.kv_transfer_params is not None and (
            not self.scheduler.get_kv_connector()
        ):
            # NOTE(review): the params are left on the request; presumably the
            # scheduler ignores them without a connector -- confirm.
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Dump engine state for ``scheduler_output`` if the body raises.

        The exception is re-raised unchanged after the dump.
        """
        try:
            yield
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare raise keeps the original traceback without adding a frame.
            raise

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()

        with self.log_error_detail(scheduler_output):
            model_output = self.model_executor.execute_model(scheduler_output)

        assert isinstance(model_output, ModelRunnerOutput)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

    def post_step(self, model_executed: bool) -> None:
        """Feed draft token ids back to the scheduler after a model step.

        Only does work when speculative decoding is enabled and the model
        actually ran.
        """
        if self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, filling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the batch queue is finished.
        3. Update the scheduler from the output.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        # (The deque has maxlen, so appending to a full queue would silently
        # drop the oldest in-flight batch.)
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            future = self.model_executor.execute_model(scheduler_output, non_block=True)
            batch_queue.appendleft((future, scheduler_output))  # type: ignore[arg-type]

            model_executed = scheduler_output.total_num_scheduled_tokens > 0
            if (
                model_executed
                and len(batch_queue) < self.batch_queue_size
                and not batch_queue[-1][0].done()
            ):
                # Don't block on next worker response unless the queue is full
                # or there are no more requests to schedule.
                return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available (oldest batch is on the right).
        future, scheduler_output = batch_queue.pop()
        with self.log_error_detail(scheduler_output):
            model_output = future.result()

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )
        return engine_core_outputs, model_executed

    def shutdown(self):
        """Release the structured-output backend, executor and scheduler."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling on the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Clear the multi-modal caches in the engine and on the executor."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(self):
        """Delegate to ``scheduler.reset_prefix_cache()``."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Delegate to ``model_executor.sleep(level)``."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: list[str] | None = None):
        """Delegate to ``model_executor.wake_up(tags)``."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` flag."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Delegate to ``model_executor.execute_dummy_batch()``."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to ``model_executor.add_lora``."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to ``model_executor.remove_lora``."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Delegate to ``model_executor.list_loras``."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to ``model_executor.pin_lora``."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Delegate to ``model_executor.save_sharded_state``."""
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Delegate to ``model_executor.collective_rpc``."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def save_tensorized_model(
        self,
        tensorizer_config,
    ) -> None:
        """Delegate to ``model_executor.save_tensorized_model``."""
        self.model_executor.save_tensorized_model(
            tensorizer_config=tensorizer_config,
        )

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The built ``Request`` and ``request.current_wave``.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

502
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
503

504
505
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        engine_index: int = 0,
    ):
        """Create the engine core and connect it to its front-end over ZMQ.

        Runs the startup handshake(s), sets up data parallelism and the base
        ``EngineCore``, and starts the daemon threads that move data between
        the ZMQ sockets and ``input_queue`` / ``output_queue``. The handshake
        stays open until the input thread sets its ready event. After startup,
        the heap is frozen for the GC and the env-var cache is enabled.

        Args:
            vllm_config: Engine configuration; the handshake may update its
                parallel config.
            local_client: Whether the client is colocated with this engine.
            handshake_address: ZMQ address of the (rank 0) front-end.
            executor_class: Executor type passed to ``EngineCore``.
            log_stats: Passed to ``EngineCore``.
            client_handshake_address: Optional address of the colocated
                front-end for a second handshake.
            engine_index: Index of this engine; encoded as its ZMQ identity.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()

        def executor_fail_callback():
            # Report executor failure as an EXECUTOR_FAILED item on the
            # input queue.
            self.input_queue.put_nowait((EngineCoreRequestType.EXECUTOR_FAILED, b""))

        self.engine_index = engine_index
        # 2-byte little-endian engine index, used as the ZMQ socket identity.
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        handshakes = self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        )
        with handshakes as addresses:
            self.client_count = len(addresses.outputs)

            # Data-parallel / coordinator settings obtained from the handshake.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            # Request-queue stats go to the coordinator only for the "internal"
            # and "hybrid" LB modes, i.e. never with external LB.
            self.publish_dp_lb_stats = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config, executor_class, log_stats, executor_fail_callback
            )

            # Daemon I/O threads move data between the ZMQ sockets and the
            # queues, so that socket I/O (which releases the GIL) and some
            # (de)serialization overlap with the model forward pass.
            inputs_ready = threading.Event()
            reader_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    identity,
                    inputs_ready,
                ),
                daemon=True,
            )
            reader_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(
                    addresses.outputs,
                    addresses.coordinator_output,
                    self.engine_index,
                ),
                daemon=True,
            )
            self.output_thread.start()

            # Keep the handshake open (so READY is not yet sent) until the
            # input thread signals readiness, i.e. until the DP coordinator's
            # READY message has been received. Fail fast if that thread died.
            while not inputs_ready.wait(timeout=10):
                if not reader_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

        # Freeze everything allocated during startup so the GC ignores it,
        # which shortens oldest-generation collection pauses.
        gc.collect()
        gc.freeze()

        # Optional GC debugging hook, attached after the freeze.
        maybe_attach_gc_debug_callback()

        # From here on, environment variables are assumed not to be
        # overridden, so their values may be cached.
        enable_envs_cache()

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.

        The ZMQ context backing the handshake sockets is closed when this
        context manager exits, rather than being left to the garbage
        collector. On normal exit, ``vllm_config.__post_init__()`` is called
        afterwards because the handshake may have changed the config.
        """
        # The context only serves the short-lived handshake sockets, so scope
        # its lifetime to this generator. Closing it may wait up to those
        # sockets' linger period so queued READY messages are flushed.
        with zmq.Context() as input_ctx:
            is_local = local_client and client_handshake_address is None
            headless = not local_client
            handshake = self._perform_handshake(
                input_ctx,
                handshake_address,
                identity,
                is_local,
                headless,
                vllm_config,
                vllm_config.parallel_config,
            )
            if client_handshake_address is None:
                with handshake as addresses:
                    yield addresses
            else:
                assert local_client
                # Second handshake with the colocated front-end, which supplies
                # the client input/output socket addresses.
                local_handshake = self._perform_handshake(
                    input_ctx, client_handshake_address, identity, True, False, vllm_config
                )
                with handshake as addresses, local_handshake as client_addresses:
                    addresses.inputs = client_addresses.inputs
                    addresses.outputs = client_addresses.outputs
                    yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Register with one front-end over a DEALER socket, then report READY.

        Opens a DEALER socket (connect, not bind; 5000 ms linger) and runs
        ``startup_handshake``, which may update ``parallel_config_to_update``.
        The resulting addresses are yielded. If the ``with`` body completes
        normally, a READY message is sent with ``num_gpu_blocks``, the DP stats
        publish address and, when ``data_parallel_size > 1``, the parallel
        config hash. If the body raises, no READY message is sent.
        """
        socket_cm = make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        )
        with socket_cm as sock:
            # Register engine with front-end (HELLO -> init message).
            yield self.startup_handshake(
                sock, local_client, headless, parallel_config_to_update
            )

            parallel_config = vllm_config.parallel_config
            # dp_stats_address carries the coordinator stats update address for
            # the external LB case, so the colocated front-end can use it (the
            # coordinator only runs with rank 0).
            ready_msg = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
                "dp_stats_address": self.frontend_stats_publish_address,
            }
            # Config hash for DP configuration validation.
            if parallel_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = parallel_config.compute_hash()

            sock.send(msgspec.msgpack.encode(ready_msg))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Send HELLO to the front-end and wait for its init message.

        Args:
            handshake_socket: Socket to the front-end.
            local_client: Sent in the HELLO message as ``"local"``.
            headless: Sent in the HELLO message as ``"headless"``.
            parallel_config: If given, each key/value of the init message's
                ``parallel_config`` mapping is set on it as an attribute.

        Returns:
            The ``addresses`` field of the decoded init message.

        Raises:
            RuntimeError: If nothing is received within
                ``HANDSHAKE_TIMEOUT_MINS`` minutes.
        """
        hello = {"status": "HELLO", "local": local_client, "headless": headless}
        handshake_socket.send(msgspec.msgpack.encode(hello))

        logger.info("Waiting for init message from front-end.")
        timeout_ms = HANDSHAKE_TIMEOUT_MINS * 60_000
        if not handshake_socket.poll(timeout=timeout_ms):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                "minutes"
            )

        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            handshake_socket.recv(), type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for attr_name, attr_value in init_message.parallel_config.items():
                setattr(parallel_config, attr_name, attr_value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers, constructs a ``DPEngineCoreProc``
        (when the data-parallel size is > 1 or ``dp_rank > 0``) or a plain
        ``EngineCoreProc`` from ``*args``/``**kwargs``, and runs its busy loop.
        The engine core, if it was constructed, is always shut down on exit.

        Args:
            *args: Positional constructor arguments for the engine core.
            dp_rank: Data-parallel rank assigned to this engine process.
            local_dp_rank: Local data-parallel rank assigned to this process.
            **kwargs: Keyword constructor arguments; must contain
                ``vllm_config``.

        Raises:
            SystemExit: When a termination signal is received.
            Exception: Any error during construction or in the busy loop is
                logged and re-raised; if the engine core had been constructed,
                the client is first notified via ``_send_engine_dead``.
        """
        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: EngineCoreProc | None = None
        try:
            parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                set_process_title("EngineCore", f"DP{dp_rank}")
                decorate_logs()
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                set_process_title("EngineCore")
                decorate_logs()
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare ``raise`` re-raises the active exception with its original
            # traceback (instead of ``raise e``).
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

802
803
804
    def _init_data_parallel(self, vllm_config: VllmConfig):
        pass

805
806
807
    def run_busy_loop(self):
        """Drive the engine core forever.

        Each iteration first waits for, and handles, client input, then runs a
        single engine step. There is no normal return: the loop ends only when
        an exception propagates, e.g. the ``SystemExit`` raised by the
        SIGINT/SIGTERM handler that ``run_engine_core`` installs.
        """
        while True:
            # Wait for work and apply any queued client requests.
            self._process_input_queue()
            # Execute one step; its outputs are pushed to the output queue.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed."""

        waited = False
819
820
821
822
823
        while (
            not self.engines_running
            and not self.scheduler.has_requests()
            and not self.batch_queue
        ):
824
825
826
827
828
829
830
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
831
            logger.debug("EngineCore loop active.")
832
833
834
835
836
837

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

838
    def _process_engine_step(self) -> bool:
839
840
841
        """Called only when there are unfinished local requests."""

        # Step the engine core.
842
        outputs, model_executed = self.step_fn()
843
        # Put EngineCoreOutputs into the output queue.
844
        for output in outputs.items() if outputs else ():
845
            self.output_queue.put_nowait(output)
846
847
        # Post-step hook.
        self.post_step(model_executed)
848

849
850
        return model_executed

851
852
853
    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch request from client.

        Args:
            request_type: Kind of request received.
            request: Decoded payload whose layout depends on ``request_type``:
                ``(request, request_wave)`` for ADD, the abort payload for
                ABORT, ``(client_idx, call_id, method_name, args)`` for UTILITY.

        Raises:
            RuntimeError: On an ``EXECUTOR_FAILED`` request.
            BaseException: A non-``Exception`` raised by a utility method
                (e.g. ``SystemExit`` from the termination signal handler) is
                re-raised after the failure reply has been queued.
        """
        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            exit_exc: BaseException | None = None
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (
                    f"Call to {method_name} method failed: {str(e)}"
                )
                # Always reply so the client is not left waiting, but do not
                # swallow process-exit signals: SystemExit is how the
                # SIGTERM/SIGINT handler in run_engine_core stops the engine,
                # and that handler only raises it once.
                if not isinstance(e, Exception):
                    exit_exc = e
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output))
            )
            if exit_exc is not None:
                raise exit_exc
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )
882
883
884
885

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
886
        arg type, try converting to msgspec object."""
887
888
889
890
891
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
892
893
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
894
            and issubclass(p.annotation, msgspec.Struct)
895
896
897
898
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, arg_types)
        )
899

900
901
902
903
904
905
906
907
908
    def _send_engine_dead(self):
        """Tell the client that this engine core is dead.

        Enqueues ``ENGINE_CORE_DEAD`` on the output queue and waits up to five
        seconds for ``output_thread`` to finish; logs a fatal error if the
        thread is still alive after that.
        """
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # The output loop in process_output_sockets exits right after sending
        # ENGINE_CORE_DEAD (presumably the loop run by output_thread), so the
        # thread finishing indicates the message went out.
        output_thread = self.output_thread
        output_thread.join(timeout=5.0)
        if not output_thread.is_alive():
            return
        logger.fatal(
            "vLLM shutdown signal from EngineCore failed "
            "to send. Please report this issue."
        )
913

914
915
916
    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Connects a DEALER socket to each address in ``input_addresses`` and,
        when ``coord_input_address`` is given, an XSUB socket to the
        coordinator. It then sets ``ready_event`` and loops forever, decoding
        each incoming ``(request_type, data...)`` multipart message and pushing
        ``(request_type, request)`` onto ``self.input_queue``.

        Args:
            input_addresses: Front-end addresses to connect DEALER sockets to.
            coord_input_address: Coordinator address, or ``None`` if there is
                no coordinator.
            identity: ZMQ identity used for every socket.
            ready_event: Set once all sockets are connected and registered.
        """
        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. The recv() must not
                # live inside the assert: under ``python -O`` asserts are
                # stripped, so READY would never be consumed here and would
                # instead reach the request-decoding loop below.
                ready_msg = coord_socket.recv()
                assert ready_msg == b"READY"
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event

            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    if request_type == EngineCoreRequestType.ADD:
                        request = add_request_decoder.decode(data_frames)
                        request = self.preprocess_add_request(request)
                    else:
                        request = generic_decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

983
984
985
    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread.

        Pulls items from ``self.output_queue`` and sends them. On
        ``ENGINE_CORE_DEAD`` the raw message is sent to every client socket and
        the thread exits. Otherwise each item is ``(client_index, outputs)``:
        ``outputs.engine_index`` is stamped, then the encoded outputs go to the
        coordinator socket when ``client_index == -1``, or to
        ``output_paths[client_index]`` via a zero-copy, tracked send.

        Args:
            output_paths: Addresses of the per-client PUSH sockets.
            coord_output_path: Coordinator address, or ``None``.
            engine_index: Stamped onto every outgoing ``EngineCoreOutputs``.
        """
        encoder = MsgpackEncoder()
        # Encode buffers available for reuse.
        free_buffers: list[bytearray] = []
        # (tracker, keep-alive ref, buffer) for sends zmq may still be reading:
        # newest on the left, oldest on the right. Outputs may contain
        # tensors/np arrays whose backing buffers were handed to zmq for
        # zero-copy send, so they stay referenced until the send completes.
        in_flight = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        with ExitStack() as stack, zmq.Context() as ctx:
            # linger ensures the ENGINE_CORE_DEAD message is flushed before
            # the sockets are closed.
            client_sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, path, zmq.PUSH, linger=4000)
                )
                for path in output_paths
            ]
            coord_socket = None
            if coord_output_path is not None:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
            max_free_buffers = len(client_sockets) + 1

            while True:
                item = self.output_queue.get()
                if item == EngineCoreProc.ENGINE_CORE_DEAD:
                    for sock in client_sockets:
                        sock.send(item)
                    break

                assert not isinstance(item, bytes)
                client_index, outputs = item
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Coordinator messages are tiny, so skip buffer reuse.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Recycle buffers from the oldest sends zmq has finished with.
                while in_flight and in_flight[-1][0].done:
                    _, _, done_buf = in_flight.pop()
                    free_buffers.append(done_buf)

                buf = free_buffers.pop() if free_buffers else bytearray()
                frames = encoder.encode_into(outputs, buf)
                tracker = client_sockets[client_index].send_multipart(
                    frames, copy=False, track=True
                )

                if tracker.done:
                    # Send already complete; keep the buffer, up to a cap.
                    if len(free_buffers) < max_free_buffers:
                        free_buffers.append(buf)
                else:
                    keep_alive = outputs if len(frames) > 1 else None
                    in_flight.appendleft((tracker, keep_alive, buf))
1052
1053
1054
1055
1056
1057
1058
1059
1060


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        # Current request wave; advanced when a wave finishes or when a
        # request / START_DP_WAVE for a later wave arrives.
        self.current_wave = 0
        # Scheduler request counts most recently published as stats.
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Set up per-rank data-parallel state and the DP process group.

        Makes the KV-transfer engine id (if configured) unique per DP rank by
        appending ``_dp{local_dp_rank}``, records ``self.dp_rank`` and creates
        ``self.dp_group``.
        """
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug(
                "Setting kv_transfer_config.engine_id to %s",
                vllm_config.kv_transfer_config.engine_id,
            )

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine core, then destroy the DP group if one exists."""
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request after reconciling its wave with ``current_wave``.

        With a coordinator, a request tagged with a later wave advances
        ``current_wave``. A request for an older wave that arrives while the
        engines are idle makes this engine ask the coordinator (client index
        -1) to start the current wave.
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle START_DP_WAVE here; delegate all other request types.

        START_DP_WAVE carries ``(new_wave, exclude_eng_index)``. Unless this
        engine is the excluded one or the wave is older than ``current_wave``,
        the wave is adopted and, if idle, ``engines_running`` is set so the
        busy loop starts stepping.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Publish scheduler request counts to the coordinator when they change.

        Does nothing unless ``publish_dp_lb_stats`` is set.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank may still have unfinished requests.

        Only every 32nd call synchronizes with the other ranks; all other calls
        return ``True`` without communicating.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild this rank's data-parallel state for a new DP configuration.

        Tears down the current DP group, applies the new DP size, rank and
        master address/port from ``reconfig_request`` to the parallel config,
        re-creates the DP group unless this rank is being shut down, forwards
        the request to the model executor and, when scaling up, shares the
        available KV-cache memory with newly joined ranks.
        """
        # NOTE(review): DPEngineCoreProc.shutdown() also destroys self.dp_group,
        # so the old group is destroyed twice here (and once more by the final
        # shutdown() below when this rank is removed, because no new group is
        # created then). Confirm the destroy helper tolerates repeated calls.
        stateless_destroy_torch_distributed_process_group(self.dp_group)
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        # Write the config's port back into the request, presumably so the
        # executor sees the port actually in use — TODO confirm.
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )
1260

Rui Qiao's avatar
Rui Qiao committed
1261
1262
1263
1264
1265
1266
1267
1268
1269

class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Yielded by _perform_handshakes; stored before super().__init__,
        # presumably because engine initialization consumes it — confirm.
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

        super().__init__(vllm_config, local_client, "", executor_class, log_stats)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict this actor's visible devices to its local DP slice.

        Skipped on XPU; on other platforms it sets the platform's
        device-control environment variable via ``_set_cuda_visible_devices``.
        """
        from vllm.platforms import current_platform

        if current_platform.is_xpu():
            # Nothing to configure on XPU.
            return
        device_control_env_var = current_platform.device_control_env_var
        self._set_cuda_visible_devices(
            vllm_config, local_dp_rank, device_control_env_var
        )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Set ``device_control_env_var`` to the device indices of this rank.

        Raises:
            RuntimeError: If ``get_device_indices`` raises ``IndexError`` for
                this rank's range. The original ``IndexError`` is chained as
                the cause.
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
            os.environ[device_control_env_var] = value
        except IndexError as e:
            # RuntimeError (a subclass of Exception) keeps existing
            # ``except Exception`` handlers working while being specific.
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        ``SystemExit`` and other exceptions are logged and re-raised; the
        engine core is shut down in every case.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()