core.py 52.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import gc
4
import os
5
import queue
6
import signal
7
8
import threading
import time
9
from collections import deque
10
from collections.abc import Callable, Generator
11
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
12
from contextlib import ExitStack, contextmanager
13
from inspect import isclass, signature
14
from logging import DEBUG
15
from typing import Any, TypeVar
16

17
import msgspec
18
19
import zmq

20
21
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
22
from vllm.distributed.parallel_state import is_global_first_rank
23
from vllm.envs import enable_envs_cache
24
from vllm.logger import init_logger
25
from vllm.logging_utils.dump_input import dump_engine_exception
26
from vllm.lora.request import LoRARequest
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.cache import engine_receiver_cache_from_config
29
from vllm.tasks import POOLING_TASKS, SupportedTask
30
31
32
33
34
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils import (
    decorate_logs,
    set_process_title,
)
35
from vllm.utils.gc_utils import maybe_attach_gc_debug_callback
36
from vllm.utils.hashing import get_hash_fn_by_name
37
from vllm.utils.import_utils import resolve_obj_by_qualname
38
from vllm.utils.network_utils import make_zmq_socket
39
40
41
42
43
44
45
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
46
from vllm.v1.core.sched.interface import SchedulerInterface
47
48
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from vllm.v1.engine import (
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
63
from vllm.v1.executor import Executor
64
from vllm.v1.kv_cache_interface import KVCacheConfig
65
from vllm.v1.metrics.stats import SchedulerStats
66
from vllm.v1.outputs import ModelRunnerOutput
67
from vllm.v1.request import Request, RequestStatus
68
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
69
from vllm.v1.structured_output import StructuredOutputManager
70
71
72
73
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

74
# Polling timeout; the `_S` suffix suggests the unit is seconds.
# NOTE(review): not referenced in the visible part of this file — confirm
# the exact meaning at its use sites.
POLLING_TIMEOUT_S = 2.5
75
# Minutes to wait for the front-end's init message during the startup
# handshake before giving up (converted to milliseconds for the socket poll).
HANDSHAKE_TIMEOUT_MINS = 5
76

77
_R = TypeVar("_R")  # Return type for collective_rpc
78

79
80
81
82

class EngineCore:
    """Inner loop of vLLM's Engine."""

83
84
85
86
87
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
    ):
        """Build the executor, KV caches, scheduler and request-hashing state.

        Args:
            vllm_config: Engine configuration. It is mutated in place: the
                cache block counts are filled in once the KV caches have been
                sized, and chunked prefill is switched off when the model has
                no KV cache groups.
            executor_class: Executor type; instantiated with ``vllm_config``.
            log_stats: Stored on the instance and forwarded to the scheduler.
            executor_fail_callback: Optional callable registered with the
                executor through ``register_failure_callback``.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        if is_global_first_rank():
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. The configured class may be given either as a
        # qualified name or as the class object itself.
        configured_scheduler_cls = vllm_config.scheduler_config.scheduler_cls
        if isinstance(configured_scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(configured_scheduler_cls)
        else:
            Scheduler = configured_scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                configured_scheduler_cls,
            )

        if len(kv_cache_config.kv_cache_groups) == 0:
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.chunked_prefill_enabled = False

        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None

        # Use the scheduler's accessor (as the rest of this class does)
        # rather than reaching into its `connector` attribute, so schedulers
        # that only implement the accessor keep working.
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            self.model_executor.init_kv_output_aggregator(kv_connector)

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = engine_receiver_cache_from_config(
            vllm_config, mm_registry
        )

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        # Block hashing is set up when prefix caching is enabled or a KV
        # connector is present; otherwise the hasher stays None.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if (
            self.vllm_config.cache_config.enable_prefix_caching
            or kv_connector is not None
        ):
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        # Pick the step implementation once, based on whether the batch queue
        # is in use.
        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )
199

200
    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Size the KV caches, create them on the executor and warm it up.

        Also records the memory available for the KV cache in
        ``self.available_gpu_memory_for_kv_cache`` when the model has one.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``.
            ``num_cpu_blocks`` is always 0 in this implementation.
        """
        # Monotonic clock: the elapsed time reported below must not be skewed
        # by wall-clock adjustments.
        start = time.monotonic()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Scale-up launch: take the KV cache memory size from the DP
                # group instead of profiling locally.
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        # One memory figure is expected per KV cache spec.
        assert len(kv_cache_specs) == len(available_gpu_memory)

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )
        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.monotonic() - start
        logger.info(
            ("init engine (profile, create kv cache, warmup model) took %.2f seconds"),
            elapsed,
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
246

247
248
249
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the executor's ``supported_tasks``."""
        executor = self.model_executor
        return executor.supported_tasks

250
251
    def add_request(self, request: Request, request_wave: int = 0):
        """Validate ``request`` and hand it to the scheduler.

        `request_wave`: the wave of requests this one is expected to belong
        to in the DP case. It is not consulted by this method.

        Raises:
            TypeError: if ``request.request_id`` is not a string.
            ValueError: if the request carries pooling params whose task is
                not one of the supported pooling tasks.
        """
        # The request id must be a string.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        pooling_params = request.pooling_params
        if pooling_params:
            # Only tasks that are both supported and pooling tasks qualify.
            supported_pooling_tasks = [
                candidate
                for candidate in self.get_supported_tasks()
                if candidate in POOLING_TASKS
            ]
            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        # The connector is looked up only when transfer params are present.
        has_transfer_params = request.kv_transfer_params is not None
        if has_transfer_params and not self.scheduler.get_kv_connector():
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)
282

283
    def abort_requests(self, request_ids: list[str]):
        """Finish the given requests in the scheduler with an aborted status."""
        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        aborted = RequestStatus.FINISHED_ABORTED
        self.scheduler.finish_requests(request_ids, aborted)
290

291
292
    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Context manager that dumps engine state if the wrapped code raises.

        On an ``Exception`` inside the ``with`` body, the config, the given
        ``scheduler_output`` and the scheduler's current stats are passed to
        ``dump_engine_exception``, and the exception is then re-raised
        unchanged.
        """
        try:
            yield
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare `raise` re-raises the active exception as-is, without
            # adding a redundant traceback entry for this line.
            raise

307
    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Run one schedule / execute / update cycle.

        Returns:
            The engine core outputs and a flag telling whether the model was
            executed, i.e. whether any tokens were scheduled. When the
            scheduler holds no requests, ``({}, False)`` is returned.
        """
        scheduler = self.scheduler

        # Nothing to do unless the scheduler still holds requests - either
        # unfinished ones, or finished ones not yet removed from the batch.
        if not scheduler.has_requests():
            return {}, False

        scheduler_output = scheduler.schedule()
        with self.log_error_detail(scheduler_output):
            model_output = self.model_executor.execute_model(scheduler_output)

        engine_core_outputs = scheduler.update_from_output(
            scheduler_output, model_output
        )
        model_executed = scheduler_output.total_num_scheduled_tokens > 0
        return engine_core_outputs, model_executed
328

329
330
331
332
333
334
335
    def post_step(self, model_executed: bool) -> None:
        """Pass draft token ids from the executor to the scheduler.

        Does nothing unless speculative decoding is in use and the model was
        executed in the step that just finished.
        """
        if not (self.use_spec_decode and model_executed):
            return

        # Take the draft token ids; the executor may have none to offer.
        drafts = self.model_executor.take_draft_token_ids()
        if drafts is None:
            return
        self.scheduler.update_draft_token_ids(drafts)

336
    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches through the batch queue.

        ``None`` is returned in place of outputs when this step produces
        nothing to output.

        Flow:
        1. If the scheduler has requests, schedule a batch, submit it without
        blocking and put it on the queue. If that batch ran the model, the
        queue still has room and the oldest queued batch is not done yet,
        return ``(None, True)`` right away: filling the queue takes priority
        over collecting model outputs.
        2. Otherwise (queue full, nothing more to schedule, or the oldest
        batch already done) block until the oldest queued batch finishes.
        3. Update the scheduler from that batch's output.
        """
        pending = self.batch_queue
        assert pending is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(pending) < self.batch_queue_size

        model_executed = False
        if self.scheduler.has_requests():
            new_output = self.scheduler.schedule()
            new_future = self.model_executor.execute_model(new_output, non_block=True)
            pending.appendleft((new_future, new_output))

            model_executed = new_output.total_num_scheduled_tokens > 0
            has_room = len(pending) < self.batch_queue_size
            if model_executed and has_room and not pending[-1][0].done():
                # Don't block on next worker response unless the queue is full
                # or there are no more requests to schedule.
                return None, True
        elif not pending:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available. New batches are pushed on
        # the left, so the right end holds the oldest one.
        future, scheduler_output = pending.pop()
        with self.log_error_detail(scheduler_output):
            model_output = future.result()

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )
        return engine_core_outputs, model_executed
391

392
    def shutdown(self):
        """Clear the structured-output backend, then stop executor and scheduler.

        The executor and the scheduler are each shut down only if truthy.
        """
        self.structured_output_manager.clear_backend()

        executor = self.model_executor
        if executor:
            executor.shutdown()

        scheduler = self.scheduler
        if scheduler:
            scheduler.shutdown()
398

399
    def profile(self, is_start: bool = True):
        """Forward the profiling request and its ``is_start`` flag to the executor."""
        executor = self.model_executor
        executor.profile(is_start)
401

402
403
    def reset_mm_cache(self):
        """Clear the multi-modal caches held here and by the executor.

        A warning is logged when the scheduler still has unfinished requests.
        """
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        requests_in_flight = self.scheduler.has_unfinished_requests()
        if requests_in_flight:
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        receiver_cache = self.mm_receiver_cache
        if receiver_cache is not None:
            receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

417
418
419
    def reset_prefix_cache(self):
        """Ask the scheduler to reset its prefix cache; the result is discarded."""
        scheduler = self.scheduler
        scheduler.reset_prefix_cache()

420
421
422
    def sleep(self, level: int = 1):
        """Forward a sleep request at the given ``level`` to the executor."""
        executor = self.model_executor
        executor.sleep(level)

423
    def wake_up(self, tags: list[str] | None = None):
424
        self.model_executor.wake_up(tags)
425

426
427
428
    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` attribute."""
        executor = self.model_executor
        return executor.is_sleeping

429
    def execute_dummy_batch(self):
        """Have the executor run a dummy batch; its result is discarded."""
        executor = self.model_executor
        executor.execute_dummy_batch()
431

432
433
434
435
436
437
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the executor's ``add_lora`` and return its result."""
        executor = self.model_executor
        return executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the executor's ``remove_lora`` and return its result."""
        executor = self.model_executor
        return executor.remove_lora(lora_id)

438
    def list_loras(self) -> set[int]:
        """Return whatever the executor's ``list_loras`` reports."""
        executor = self.model_executor
        return executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Forward ``lora_id`` to the executor's ``pin_lora`` and return its result."""
        executor = self.model_executor
        return executor.pin_lora(lora_id)
443

444
445
446
    def save_sharded_state(
        self,
        path: str,
447
448
        pattern: str | None = None,
        max_size: int | None = None,
449
    ) -> None:
450
451
452
453
454
455
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Forward a collective RPC to the executor and return its result.

        The four arguments are passed on positionally, in the order
        ``method, timeout, args, kwargs``.
        """
        executor = self.model_executor
        return executor.collective_rpc(method, timeout, args, kwargs)
462

463
    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Turn an ``EngineCoreRequest`` into a scheduler ``Request``.

        Can be called directly from the input processing thread so that
        request initialization overlaps with the model forward pass.

        Returns:
            The constructed request together with ``request.current_wave``.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        receiver_cache = self.mm_receiver_cache
        if receiver_cache is not None and request.mm_features:
            request.mm_features = receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if not req.use_structured_output:
            return req, request.current_wave

        # Note on thread safety: no race condition.
        # `grammar_init` is only invoked in input processing thread. For
        # `structured_output_manager`, each request is independent and
        # grammar compilation is async. Scheduler always checks grammar
        # compilation status before scheduling request.
        self.structured_output_manager.grammar_init(req)
        return req, request.current_wave

487
488
489
490

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

491
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
492

493
494
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        engine_index: int = 0,
    ):
        """Perform the startup handshake(s) and bring up the engine core.

        Inside the handshake context this sets up the data-parallel state,
        initializes the base ``EngineCore``, starts the input and output
        socket threads and waits for the input thread's ready event. After
        that the startup heap is frozen for the GC and the environment
        variable cache is enabled.

        Raises:
            RuntimeError: if the input socket thread dies before signalling
                readiness.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()

        def executor_fail_callback():
            # Report an executor failure by enqueueing an EXECUTOR_FAILED
            # request on the input queue.
            return self.input_queue.put_nowait(
                (EngineCoreRequestType.EXECUTOR_FAILED, b"")
            )

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        handshakes = self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        )
        with handshakes as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            # Only publish request queue stats to coordinator for "internal"
            # and "hybrid" LB modes.
            self.publish_dp_lb_stats = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config, executor_class, log_stats, executor_fail_callback
            )

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_thread_args = (
                addresses.inputs,
                addresses.coordinator_input,
                identity,
                ready_event,
            )
            input_thread = threading.Thread(
                target=self.process_input_sockets,
                args=input_thread_args,
                daemon=True,
            )
            input_thread.start()

            output_thread_args = (
                addresses.outputs,
                addresses.coordinator_output,
                self.engine_index,
            )
            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=output_thread_args,
                daemon=True,
            )
            self.output_thread.start()

            # Don't complete handshake until DP coordinator ready message is
            # received.
            while True:
                if ready_event.wait(timeout=10):
                    break
                if not input_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        gc.collect()
        gc.freeze()

        # If enable, attach GC debugger after static variable freeze.
        maybe_attach_gc_debug_callback()

        # Enable environment variable cache (e.g. assume no more
        # environment variable overrides after this point)
        enable_envs_cache()

Rui Qiao's avatar
Rui Qiao committed
594
    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, the handshake is with the colocated
        front-end process.

        For DP>1 with internal load-balancing, the handshake is with the
        shared front-end process, which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        The rank 0 and colocated engines themselves are the exception: they
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.
        """
        input_ctx = zmq.Context()
        second_handshake = client_handshake_address is not None
        is_local = local_client and not second_handshake
        headless = not local_client

        primary = self._perform_handshake(
            input_ctx,
            handshake_address,
            identity,
            is_local,
            headless,
            vllm_config,
            vllm_config.parallel_config,
        )
        if second_handshake:
            assert local_client
            colocated = self._perform_handshake(
                input_ctx, client_handshake_address, identity, True, False, vllm_config
            )
            with primary as addresses, colocated as client_addresses:
                # The client input/output addresses come from the second
                # handshake; everything else from the first.
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
                yield addresses
        else:
            with primary as addresses:
                yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run one handshake over a DEALER socket connected to ``handshake_address``.

        Yields the addresses returned by ``startup_handshake``. Once the
        ``with`` body of the caller completes, a READY message is sent on the
        same socket before it is closed.
        """
        handshake_socket_cm = make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        )
        with handshake_socket_cm as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(
                handshake_socket, local_client, headless, parallel_config_to_update
            )
            yield addresses

            # Send ready message.
            # Include config hash for DP configuration validation
            ready_msg = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
                # We pass back the coordinator stats update address here for
                # the external LB case for our colocated front-end to use
                # (coordinator only runs with rank 0).
                "dp_stats_address": self.frontend_stats_publish_address,
            }
            parallel_config = vllm_config.parallel_config
            if parallel_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = parallel_config.compute_hash()

            handshake_socket.send(msgspec.msgpack.encode(ready_msg))
Rui Qiao's avatar
Rui Qiao committed
699

700
    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Send HELLO to the front-end and wait for its init message.

        When ``parallel_config`` is given, every key/value pair in the init
        message's ``parallel_config`` is set on it as an attribute.

        Returns:
            The ``addresses`` field of the init message.

        Raises:
            RuntimeError: if no reply arrives within
                ``HANDSHAKE_TIMEOUT_MINS`` minutes.
        """
        # Send registration message.
        hello = {
            "status": "HELLO",
            "local": local_client,
            "headless": headless,
        }
        handshake_socket.send(msgspec.msgpack.encode(hello))

        # Receive initialization message.
        logger.info("Waiting for init message from front-end.")
        timeout_ms = HANDSHAKE_TIMEOUT_MINS * 60_000
        if not handshake_socket.poll(timeout=timeout_ms):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                f"minutes"
            )
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            handshake_socket.recv(), type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is None:
            return init_message.addresses

        # Apply the parallel-config settings supplied by the front-end.
        for field_name, field_value in init_message.parallel_config.items():
            setattr(parallel_config, field_name, field_value)
        return init_message.addresses
737
738

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers, builds either an ``EngineCoreProc``
        or a ``DPEngineCoreProc`` from ``*args``/``**kwargs`` and runs its
        busy loop until the process is told to exit or a fatal error occurs.

        Args:
            *args: Forwarded unchanged to the engine-core constructor.
            dp_rank: Data parallel rank written into the parallel config
                when the data-parallel engine is selected.
            local_dp_rank: Local data parallel rank written alongside it.
            **kwargs: Forwarded to the constructor; must contain
                ``"vllm_config"``.

        Raises:
            SystemExit: Re-raised after the first termination signal.
            Exception: Any start-up or runtime failure is logged and
                re-raised; if the engine was already constructed, the
                engine-dead notification is sent first.
        """
        # Flag for graceful termination: SystemExit is raised only for the
        # first signal so this process and its workers can terminate
        # without error.
        exit_signalled = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def _on_termination_signal(signum, frame):
            nonlocal exit_signalled
            if exit_signalled:
                return
            exit_signalled = True
            raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        for sig in (signal.SIGTERM, signal.SIGINT):
            signal.signal(sig, _on_termination_signal)

        engine_core: EngineCoreProc | None = None
        try:
            parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
            use_data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
            if not use_data_parallel:
                set_process_title("EngineCore")
                decorate_logs()
                engine_core = EngineCoreProc(*args, **kwargs)
            else:
                set_process_title("EngineCore", f"DP{dp_rank}")
                decorate_logs()
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception as err:
            if engine_core is not None:
                # The engine was running: tell the client it is dead.
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            else:
                logger.exception("EngineCore failed to start.")
            raise err
        finally:
            # Always release engine resources if construction succeeded.
            if engine_core is not None:
                engine_core.shutdown()

791
792
793
    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Data-parallel setup hook; does nothing in this class.

        ``DPEngineCoreProc`` overrides this method to configure its data
        parallel rank and process group.
        """
        return None

794
795
796
    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Alternates forever between handling client input and stepping the
        engine. The loop has no exit condition of its own; it ends only
        when one of the two calls raises.
        """
        wait_for_work = self._process_input_queue
        step_engine = self._process_engine_step

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            wait_for_work()
            # 2) Step the engine core and return the outputs.
            step_engine()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed.

        While the engine is idle (engines not running, scheduler reports no
        requests, and the batch queue is empty) this blocks on the input
        queue and dispatches each request as it arrives. Once there is
        work, any requests already queued are dispatched without blocking.
        """
        logged_idle = False
        while not (
            self.engines_running
            or self.scheduler.has_requests()
            or self.batch_queue
        ):
            # Only announce the wait when we are really about to block.
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                logged_idle = True
            item = self.input_queue.get()
            self._handle_client_request(*item)

        if logged_idle:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            item = self.input_queue.get_nowait()
            self._handle_client_request(*item)

827
    def _process_engine_step(self) -> bool:
828
829
830
        """Called only when there are unfinished local requests."""

        # Step the engine core.
831
        outputs, model_executed = self.step_fn()
832
        # Put EngineCoreOutputs into the output queue.
833
        for output in outputs.items() if outputs else ():
834
            self.output_queue.put_nowait(output)
835
836
        # Post-step hook.
        self.post_step(model_executed)
837

838
839
        return model_executed

840
841
842
    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch request from client.

        Args:
            request_type: Kind of request being handled.
            request: Payload whose shape depends on ``request_type``:
                ``ADD`` carries ``(req, request_wave)``, ``ABORT`` is passed
                straight to ``abort_requests``, and ``UTILITY`` carries
                ``(client_idx, call_id, method_name, args)``.

        Raises:
            RuntimeError: For ``EXECUTOR_FAILED`` requests.
            SystemExit, KeyboardInterrupt: Propagated if raised while a
                utility method is running.
        """
        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except (SystemExit, KeyboardInterrupt):
                # Never turn a shutdown request into a utility failure: the
                # termination handler installed by run_engine_core raises
                # SystemExit only once, so swallowing it here would lose
                # the signal for good.
                raise
            except BaseException as e:
                # Any other failure is reported back to the calling client
                # instead of taking the engine down.
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (
                    f"Call to {method_name} method failed: {str(e)}"
                )
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output))
            )
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )
871
872
873
874

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
875
        arg type, try converting to msgspec object."""
876
877
878
879
880
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
881
882
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
883
            and issubclass(p.annotation, msgspec.Struct)
884
885
886
887
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, arg_types)
        )
888

889
890
891
892
893
894
895
896
897
    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient.

        Queues the ``ENGINE_CORE_DEAD`` marker and then waits up to five
        seconds for the output thread to finish. If that thread is still
        alive afterwards, the failure is logged.
        """
        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until msg sent by the daemon before shutdown.
        sender = self.output_thread
        sender.join(timeout=5.0)
        if not sender.is_alive():
            return

        logger.fatal(
            "vLLM shutdown signal from EngineCore failed "
            "to send. Please report this issue."
        )
902

903
904
905
    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Connects a DEALER socket to every input address (and an XSUB socket
        to the coordinator, when one is configured), signals readiness via
        ``ready_event`` and then loops forever, decoding incoming requests
        and pushing ``(request_type, request)`` tuples onto
        ``self.input_queue``.

        Args:
            input_addresses: Addresses of the front-end input sockets.
            coord_input_address: Coordinator address, or ``None`` when
                there is no coordinator.
            identity: ZMQ identity assigned to every socket created here.
            ready_event: Set once all sockets are connected and registered.

        Raises:
            RuntimeError: If the coordinator's first message is not
                ``b"READY"``.
        """
        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. The receive must
                # happen unconditionally (not inside an ``assert``, which is
                # removed under ``-O``); otherwise the READY frame would stay
                # queued and later be decoded as a request.
                ready_msg = coord_socket.recv()
                if ready_msg != b"READY":
                    raise RuntimeError(
                        "Expected READY message from coordinator, "
                        f"got {ready_msg!r}"
                    )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            # Drop our reference to the event; it is not needed any more.
            del ready_event
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    if request_type == EngineCoreRequestType.ADD:
                        request = add_request_decoder.decode(data_frames)
                        request = self.preprocess_add_request(request)
                    else:
                        request = generic_decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

972
973
974
    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread.

        Opens one PUSH socket per output path (plus one for the coordinator
        when ``coord_output_path`` is given), then forwards everything taken
        from ``self.output_queue`` until the ``ENGINE_CORE_DEAD`` marker is
        seen. That marker is sent to every client socket before returning.

        Args:
            output_paths: Paths of the per-client output sockets; a queued
                item's client index selects the socket by position.
            coord_output_path: Coordinator path, or ``None`` when there is
                no coordinator. Items with client index ``-1`` go there.
            engine_index: Stamped onto every outgoing outputs object.
        """
        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        spare_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        in_flight: deque[tuple[zmq.MessageTracker, Any, bytearray]] = deque()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = []
            for output_path in output_paths:
                sockets.append(
                    stack.enter_context(
                        make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                    )
                )

            coord_socket = None
            if coord_output_path is not None:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
            max_spare_buffers = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    # Broadcast the death notice to every client, then stop.
                    for client_socket in sockets:
                        client_socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with.
                while in_flight and in_flight[-1][0].done:
                    spare_buffers.append(in_flight.pop()[2])

                if spare_buffers:
                    buffer = spare_buffers.pop()
                else:
                    buffer = bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(
                    buffers, copy=False, track=True
                )
                if tracker.done:
                    # Limit the number of buffers to reuse.
                    if len(spare_buffers) < max_spare_buffers:
                        spare_buffers.append(buffer)
                    continue

                # zmq has not finished with the frames yet: park the buffer
                # (and the outputs object when extra frames were produced)
                # until the tracker reports completion.
                keepalive = outputs if len(buffers) > 1 else None
                in_flight.appendleft((tracker, keepalive, buffer))
1041
1042
1043
1044
1045
1046
1047
1048
1049


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        """Initialise DP bookkeeping, then the underlying engine core.

        The data parallel rank is read from
        ``vllm_config.parallel_config.data_parallel_rank`` and passed to the
        parent constructor as its last positional argument.
        """
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        self.current_wave = 0
        # Last request counts published by _maybe_publish_request_counts.
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Record the DP rank and create the stateless DP process group.

        Also makes ``kv_transfer_config.engine_id`` (when such a config is
        present) unique per rank by appending the local DP rank.
        """
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug(
                "Setting kv_transfer_config.engine_id to %s",
                vllm_config.kv_transfer_config.engine_id,
            )

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def _destroy_dp_group(self) -> None:
        """Destroy the current DP process group at most once.

        The attribute is cleared before the group is destroyed so that a
        later call (for example from ``shutdown``) becomes a no-op instead
        of destroying the same group again. ``getattr`` is used because
        ``dp_group`` may not have been assigned yet.
        """
        dp_group = getattr(self, "dp_group", None)
        if not dp_group:
            return
        self.dp_group = None
        stateless_destroy_torch_distributed_process_group(dp_group)

    def shutdown(self):
        """Shut down the engine core and release the DP process group."""
        super().shutdown()
        self._destroy_dp_group()

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request, reconciling its wave number with ours first.

        With a coordinator, a request from a newer wave advances
        ``current_wave``. A request from an older wave received while the
        engines are not running queues a ``start_wave`` notification
        (client index ``-1``).
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle ``START_DP_WAVE`` here; delegate everything else.

        A ``START_DP_WAVE`` request carries ``(new_wave, exclude_eng_index)``.
        It is ignored when it excludes this engine or refers to a wave older
        than the current one; otherwise the wave is adopted and the engine
        is marked as running.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Queue scheduler request counts when publishing is enabled.

        Nothing is queued unless ``publish_dp_lb_stats`` is set and the
        counts differ from the last published ones.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Report whether any DP rank may still have unfinished requests.

        Increments ``step_counter`` on every call. Only every 32nd call
        consults ``ParallelConfig.has_unfinished_dp``; all other calls
        return ``True`` without checking.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild the DP setup according to ``reconfig_request``.

        Updates the parallel config (size, rank, master ip/port) in place,
        re-creates the DP process group unless the rank is being shut down,
        writes the resulting master port back into ``reconfig_request`` and
        asks the model executor to reinitialise. When scaling up, the
        available KV-cache memory is synchronised across the group and the
        model is warmed up again.
        """
        # Destroy the old group exactly once: shutdown() below (and on the
        # SHUTDOWN_CURRENT_RANK path at the end) would otherwise destroy
        # the very same group again.
        self._destroy_dp_group()
        # NOTE(review): shutdown() runs before the executor is reconfigured
        # below; what the parent's shutdown() tears down is not visible
        # here - confirm the executor stays usable afterwards.
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        # NOTE(review): -1 presumably equals
        # ReconfigureRankType.KEEP_CURRENT_RANK - confirm against the enum.
        if reconfig_request.new_data_parallel_rank != -1:
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        # NOTE(review): -2 presumably equals
        # ReconfigureRankType.SHUTDOWN_CURRENT_RANK - confirm against the enum.
        if reconfig_request.new_data_parallel_rank != -2:
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        # Hand the master port now stored in the config back to the request
        # before it is forwarded to the executor.
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )
1249

Rui Qiao's avatar
Rui Qiao committed
1250
1251
1252
1253
1254
1255
1256
1257
1258

class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        """Store the ZMQ addresses, pin device visibility, then build the engine.

        ``dp_rank`` and ``local_dp_rank`` are written into
        ``vllm_config.parallel_config`` before anything else happens. The
        parent constructor is given an empty handshake address.
        """
        self.addresses = addresses
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_rank = dp_rank
        parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle.
        # NOTE: with MP, CUDA_VISIBLE_DEVICES is set at process creation
        # time; that approach is not available under Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # Bypassing 2 would also require setting
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers
        # created thereafter would have CUDA_VISIBLE_DEVICES set, which is
        # sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # That is a problem: when the vLLM worker (a Ray actor) executes a
        # task, it indexes into the sticky CUDA_VISIBLE_DEVICES rather than
        # using the GPU ID directly, which can end in an index out of
        # bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

        super().__init__(vllm_config, local_client, "", executor_class, log_stats)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict visible devices for this actor; nothing is done on XPU."""
        from vllm.platforms import current_platform

        if current_platform.is_xpu():
            return

        self._set_cuda_visible_devices(
            vllm_config, local_dp_rank, current_platform.device_control_env_var
        )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Write this rank's device indices into ``device_control_env_var``.

        Raises:
            Exception: Chained from the ``IndexError`` raised while the
                device indices are being determined.
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            os.environ[device_control_env_var] = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            # Report the half-open index range this rank tried to claim.
            range_start = local_dp_rank * world_size
            range_end = (local_dp_rank + 1) * world_size
            raise Exception(
                f"Error setting {device_control_env_var}: "
                f"local range: [{range_start}, "
                f"{range_end}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        return None

    def run(self):
        """
        Run the engine core busy loop.

        Exit requests and fatal errors are logged and re-raised;
        ``shutdown()`` is called on every way out.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            # Normal termination path.
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()