core.py 52.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import gc
4
import os
5
import queue
6
import signal
7
8
import threading
import time
9
from collections import deque
10
from collections.abc import Callable, Generator
11
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
12
from contextlib import ExitStack, contextmanager
13
from inspect import isclass, signature
14
from logging import DEBUG
15
from typing import Any, TypeVar
16

17
import msgspec
18
19
import zmq

20
21
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
22
from vllm.envs import enable_envs_cache
23
from vllm.logger import init_logger
24
from vllm.logging_utils.dump_input import dump_engine_exception
25
from vllm.lora.request import LoRARequest
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.multimodal.cache import engine_receiver_cache_from_config
28
from vllm.tasks import POOLING_TASKS, SupportedTask
29
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
30
from vllm.utils.gc_utils import maybe_attach_gc_debug_callback
31
from vllm.utils.hashing import get_hash_fn_by_name
32
from vllm.utils.import_utils import resolve_obj_by_qualname
33
from vllm.utils.network_utils import make_zmq_socket
34
from vllm.utils.system_utils import decorate_logs, set_process_title
35
36
37
38
39
40
41
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
42
from vllm.v1.core.sched.interface import SchedulerInterface
43
44
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from vllm.v1.engine import (
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
59
from vllm.v1.executor import Executor
60
from vllm.v1.kv_cache_interface import KVCacheConfig
61
from vllm.v1.metrics.stats import SchedulerStats
62
from vllm.v1.outputs import ModelRunnerOutput
63
from vllm.v1.request import Request, RequestStatus
64
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
65
from vllm.v1.structured_output import StructuredOutputManager
66
67
68
69
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

70
POLLING_TIMEOUT_S = 2.5
71
HANDSHAKE_TIMEOUT_MINS = 5
72

73
_R = TypeVar("_R")  # Return type for collective_rpc
74

75
76
77
78

class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the KV-cache setup, the scheduler and the
    structured-output manager. Each iteration is driven through
    ``self.step_fn``, which is :meth:`step`, or :meth:`step_with_batch_queue`
    when the executor reports more than one concurrent batch.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
    ):
        """Create the executor, initialize the KV cache and the scheduler.

        Args:
            vllm_config: Engine configuration. ``cache_config.num_gpu_blocks``
                / ``num_cpu_blocks`` and, for models without KV cache groups,
                ``scheduler_config.chunked_prefill_enabled`` are updated in
                place.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Forwarded to the scheduler.
            executor_fail_callback: If given, registered on the executor via
                ``register_failure_callback``.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        if vllm_config.parallel_config.data_parallel_rank == 0:
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        # Placeholder; set by _initialize_kv_caches() when the model has a
        # KV cache.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls
            )
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls,
            )

        if len(kv_cache_config.kv_cache_groups) == 0:
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.chunked_prefill_enabled = False

        # The scheduler and the request block hasher below both use a block
        # size of cache block_size * decode_context_parallel_size.
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None

        # Use the SchedulerInterface accessor rather than the V1Scheduler-only
        # `connector` attribute, so custom schedulers that implement the
        # interface are supported here too.
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            self.model_executor.init_kv_output_aggregator(kv_connector)

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = engine_receiver_cache_from_config(
            vllm_config, mm_registry
        )

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        # Per-request block hasher; only created when prefix caching is
        # enabled or a KV connector is configured.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if (
            self.vllm_config.cache_config.enable_prefix_caching
            or kv_connector is not None
        ):
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )

    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Size the KV cache, build its configs and initialize the workers.

        Also records ``self.available_gpu_memory_for_kv_cache`` when the model
        has a KV cache.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``;
            ``num_cpu_blocks`` is always 0 here.
        """
        # perf_counter() is monotonic, so the reported duration cannot be
        # skewed by wall-clock adjustments the way time.time() can.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Elastic-EP scale-up: take the memory size from the DP group
                # instead of profiling locally.
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )
        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info_once(
            ("init engine (profile, create kv cache, warmup model) took %.2f seconds"),
            elapsed,
            scope="local",
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the model executor."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case

        Raises:
            TypeError: If ``request.request_id`` is not a ``str``.
            ValueError: If the request's pooling task is not supported.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        if request.kv_transfer_params is not None and (
            not self.scheduler.get_kv_connector()
        ):
            # NOTE(review): kv_transfer_params is left on the request;
            # presumably it is ignored downstream when no connector exists —
            # confirm.
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Dump engine state for ``scheduler_output`` if the body raises.

        The exception is re-raised unchanged after the dump.
        """
        try:
            yield
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare `raise` re-raises the active exception with its original
            # traceback.
            raise

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()

        with self.log_error_detail(scheduler_output):
            model_output = self.model_executor.execute_model(scheduler_output)

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

    def post_step(self, model_executed: bool) -> None:
        """Feed draft tokens back to the scheduler after a spec-decode step."""
        if self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, filling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the
        oldest batch in the batch queue is finished.
        3. Update the scheduler from the output.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            future = self.model_executor.execute_model(scheduler_output, non_block=True)
            # New batches go on the left; the oldest batch is popped from the
            # right below.
            batch_queue.appendleft((future, scheduler_output))

            model_executed = scheduler_output.total_num_scheduled_tokens > 0
            if (
                model_executed
                and len(batch_queue) < self.batch_queue_size
                and not batch_queue[-1][0].done()
            ):
                # Don't block on next worker response unless the queue is full
                # or there are no more requests to schedule.
                return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output = batch_queue.pop()
        with self.log_error_detail(scheduler_output):
            model_output = future.result()

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, model_executed

    def shutdown(self):
        """Release the structured-output backend, executor and scheduler."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(self):
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        self.model_executor.sleep(level)

    def wake_up(self, tags: list[str] | None = None):
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Forward a collective RPC to the model executor."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The constructed ``Request`` and ``request.current_wave``.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave


class EngineCoreProc(EngineCore):
    """ZMQ wrapper that runs an EngineCore in a background process."""

    # Raw-bytes sentinel; `output_queue` (set in __init__) accepts `bytes`
    # items as well as (int, EngineCoreOutputs) tuples. NOTE(review):
    # presumably emitted via `_send_engine_dead()` (not visible here) to tell
    # clients the engine has died — confirm.
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, build the engine and start IO threads.

        Args:
            vllm_config: Engine configuration; the handshake may update its
                parallel config.
            local_client: Whether the front-end client is local;
                `_perform_handshakes` derives `headless = not local_client`.
            handshake_address: ZMQ address of the front-end handshake socket.
            executor_class: Executor type forwarded to `EngineCore.__init__`.
            log_stats: Forwarded to `EngineCore.__init__`.
            client_handshake_address: If set, a second handshake is performed
                with this address and its input/output socket addresses
                replace those from the first handshake.
            engine_index: Index of this engine; its 2-byte little-endian
                encoding is used as the ZMQ identity.

        Raises:
            RuntimeError: If the input socket thread dies before signalling
                readiness.
        """
        # Work items for the engine, typed (request type, payload).
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        # Items for the output thread: either an (int, outputs) tuple or raw
        # bytes. NOTE(review): the int is presumably a client index — confirm.
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()
        # On executor failure, enqueue an EXECUTOR_FAILED item so the failure
        # is delivered through the same queue as regular input.
        executor_fail_callback = lambda: self.input_queue.put_nowait(
            (EngineCoreRequestType.EXECUTOR_FAILED, b"")
        )

        self.engine_index = engine_index
        # 2-byte ZMQ identity for the handshake DEALER socket(s) and the
        # input socket thread.
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        # Everything inside this `with` runs before the READY message is sent
        # (it is sent after the yield in `_perform_handshake`), so the
        # front-end sees the engine as ready only once the engine is built,
        # num_gpu_blocks is known and the IO threads are running.
        with self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        ) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            # Only publish request queue stats to coordinator for "internal"
            # and "hybrid" LB modes .
            self.publish_dp_lb_stats = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )

            # NOTE(review): defined outside the visible chunk; presumably a
            # no-op for the non-DP base class — confirm.
            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config, executor_class, log_stats, executor_fail_callback
            )

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    identity,
                    ready_event,
                ),
                daemon=True,
            )
            input_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(
                    addresses.outputs,
                    addresses.coordinator_output,
                    self.engine_index,
                ),
                daemon=True,
            )
            self.output_thread.start()

            # Don't complete handshake until DP coordinator ready message is
            # received.
            while not ready_event.wait(timeout=10):
                if not input_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                # A still-unset event is only expected while waiting for the
                # coordinator.
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        gc.collect()
        gc.freeze()

        # If enable, attach GC debugger after static variable freeze.
        maybe_attach_gc_debug_callback()

        # Enable environment variable cache (e.g. assume no more
        # environment variable overrides after this point)
        enable_envs_cache()

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        The rank 0 and colocated engines themselves are the exception: they
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.

        The READY message(s) are sent when the caller's ``with`` body exits
        normally, after which ``vllm_config.__post_init__()`` is re-run.
        """
        input_ctx = zmq.Context()
        try:
            is_local = local_client and client_handshake_address is None
            headless = not local_client
            handshake = self._perform_handshake(
                input_ctx,
                handshake_address,
                identity,
                is_local,
                headless,
                vllm_config,
                vllm_config.parallel_config,
            )
            if client_handshake_address is None:
                with handshake as addresses:
                    yield addresses
            else:
                assert local_client
                local_handshake = self._perform_handshake(
                    input_ctx, client_handshake_address, identity, True, False, vllm_config
                )
                with handshake as addresses, local_handshake as client_addresses:
                    addresses.inputs = client_addresses.inputs
                    addresses.outputs = client_addresses.outputs
                    yield addresses
        finally:
            # The handshake sockets were closed by their `with` blocks above.
            # Terminate the context explicitly instead of leaving it (and its
            # IO thread) alive until garbage collection. term() waits for
            # pending messages up to the sockets' linger period.
            input_ctx.term()

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run a single HELLO/READY handshake with one front-end process.

        Connects a DEALER socket to ``handshake_address``, registers through
        :meth:`startup_handshake`, and yields the addresses it returns. When
        the caller's ``with`` body completes without raising, a READY message
        is sent on the same socket. The message carries this engine's
        ``num_gpu_blocks``, the stats publish address and, for DP > 1, the
        parallel-config hash.
        """
        socket_cm = make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        )
        with socket_cm as sock:
            # Register engine with front-end and hand the addresses to caller.
            yield self.startup_handshake(
                sock, local_client, headless, parallel_config_to_update
            )

            # Send ready message.
            # The coordinator stats update address is passed back here for the
            # external LB case, for our colocated front-end to use (coordinator
            # only runs with rank 0).
            dp_config = vllm_config.parallel_config
            ready_payload = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
                "dp_stats_address": self.frontend_stats_publish_address,
            }
            # Include config hash for DP configuration validation
            if dp_config.data_parallel_size > 1:
                ready_payload["parallel_config_hash"] = dp_config.compute_hash()

            sock.send(msgspec.msgpack.encode(ready_payload))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Send HELLO to the front-end and wait for its init metadata.

        If ``parallel_config`` is given, every entry of the received
        ``parallel_config`` mapping is applied to it with ``setattr``.

        Returns:
            The ``addresses`` field of the received init message.

        Raises:
            RuntimeError: If no reply arrives within HANDSHAKE_TIMEOUT_MINS.
        """
        hello = {"status": "HELLO", "local": local_client, "headless": headless}
        handshake_socket.send(msgspec.msgpack.encode(hello))

        # Receive initialization message.
        logger.debug("Waiting for init message from front-end.")
        poll_timeout_ms = HANDSHAKE_TIMEOUT_MINS * 60_000
        if not handshake_socket.poll(timeout=poll_timeout_ms):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                f"minutes"
            )

        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            handshake_socket.recv(), type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for attr_name, attr_value in init_message.parallel_config.items():
                setattr(parallel_config, attr_name, attr_value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers that raise ``SystemExit`` once. Builds
        a ``DPEngineCoreProc`` when data parallelism is in use
        (``data_parallel_size > 1`` or ``dp_rank > 0``) and a plain
        ``EngineCoreProc`` otherwise, then runs its busy loop. The engine is
        always shut down on exit.

        Args:
            *args: Positional arguments forwarded to the engine constructor.
            dp_rank: Data-parallel rank stored as
                ``parallel_config.data_parallel_rank``.
            local_dp_rank: Stored as ``parallel_config.data_parallel_rank_local``.
            **kwargs: Keyword arguments forwarded to the engine constructor;
                must contain ``vllm_config``.
        """
        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: EngineCoreProc | None = None
        try:
            parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                set_process_title("EngineCore", f"DP{dp_rank}")
                decorate_logs()
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                set_process_title("EngineCore")
                decorate_logs()
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare ``raise`` re-raises the active exception with its original
            # traceback, rather than re-raising it by name from this frame.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

788
789
790
    def _init_data_parallel(self, vllm_config: VllmConfig):
        pass

791
792
793
    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Alternates forever between handling client requests and stepping the
        engine. There is no normal exit; the loop ends only when an exception
        propagates, e.g. the ``SystemExit`` raised on SIGINT/SIGTERM.
        """
        while True:
            # Block until there is work, dispatching queued client requests.
            self._process_input_queue()
            # Run one engine step and enqueue whatever it produced.
            self._process_engine_step()

    def _process_input_queue(self):
        """Handle client requests until an engine step is due.

        While the engine is idle (``engines_running`` is falsy, the scheduler
        reports no requests and ``batch_queue`` is empty), this blocks on
        ``input_queue`` and dispatches each request as it arrives. Once there
        is work, it dispatches whatever is still queued without blocking and
        returns.
        """
        logged_wait = False
        while not (
            self.engines_running
            or self.scheduler.has_requests()
            or self.batch_queue
        ):
            # The empty() probe only feeds the debug message, so it is skipped
            # unless DEBUG logging is enabled.
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                logged_wait = True
            self._handle_client_request(*self.input_queue.get())

        if logged_wait:
            logger.debug("EngineCore loop active.")

        # Drain any remaining requests without blocking.
        while not self.input_queue.empty():
            self._handle_client_request(*self.input_queue.get_nowait())

824
    def _process_engine_step(self) -> bool:
825
826
827
        """Called only when there are unfinished local requests."""

        # Step the engine core.
828
        outputs, model_executed = self.step_fn()
829
        # Put EngineCoreOutputs into the output queue.
830
        for output in outputs.items() if outputs else ():
831
            self.output_queue.put_nowait(output)
832
833
        # Post-step hook.
        self.post_step(model_executed)
834

835
836
        return model_executed

837
838
839
    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch request from client.

        Args:
            request_type: Kind of request decoded by the input thread.
            request: Payload for that kind. For ``ADD`` it is a
                ``(request, request_wave)`` pair; for ``UTILITY`` it is a
                ``(client_idx, call_id, method_name, args)`` tuple.

        Raises:
            RuntimeError: On ``EXECUTOR_FAILED``.
        """
        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except Exception as e:
                # Ordinary failures are reported back to the caller. SystemExit
                # and KeyboardInterrupt are deliberately not caught: the
                # SIGTERM/SIGINT handler raises SystemExit only once, so
                # swallowing it here would leave the engine unable to shut down.
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = f"Call to {method_name} method failed: {e}"
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output))
            )
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )
868
869
870
871

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
872
        arg type, try converting to msgspec object."""
873
874
875
876
877
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
878
879
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
880
            and issubclass(p.annotation, msgspec.Struct)
881
882
883
884
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, arg_types)
        )
885

886
887
888
889
890
891
892
893
894
    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient.

        Enqueues the ``ENGINE_CORE_DEAD`` sentinel, then waits up to five
        seconds for ``self.output_thread`` to finish. Logs a fatal error if
        the thread is still alive after that.
        """
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Give the output thread a chance to deliver the message before the
        # process proceeds to shut down.
        output_thread = self.output_thread
        output_thread.join(timeout=5.0)
        if not output_thread.is_alive():
            return
        logger.fatal(
            "vLLM shutdown signal from EngineCore failed "
            "to send. Please report this issue."
        )
899

900
901
902
    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Setup:
        - Connects a DEALER socket to each address in ``input_addresses``.
        - When ``coord_input_address`` is given, connects an XSUB socket to
          the coordinator and waits for its ``b"READY"`` message.
        - Sets ``ready_event``.

        It then loops forever, decoding incoming requests and pushing
        ``(request_type, request)`` tuples onto ``self.input_queue``.

        Raises:
            RuntimeError: If the coordinator's first message is not
                ``b"READY"``.
        """
        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. The recv() must not
                # sit inside an assert: under ``python -O`` the whole statement
                # is stripped, leaving READY unread in the socket to be
                # mis-parsed as a request type by the loop below.
                ready_msg = coord_socket.recv()
                if ready_msg != b"READY":
                    raise RuntimeError(
                        "Expected READY message from coordinator, "
                        f"got {ready_msg!r}"
                    )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    if request_type == EngineCoreRequestType.ADD:
                        request = add_request_decoder.decode(data_frames)
                        request = self.preprocess_add_request(request)
                    else:
                        request = generic_decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

969
970
971
    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread.

        Pulls items from ``self.output_queue`` and sends them:
        - ``(client_index, outputs)`` pairs go to the PUSH socket for that
          client index.
        - Pairs with ``client_index == -1`` go to the coordinator socket.
        - The ``ENGINE_CORE_DEAD`` sentinel is forwarded to every client
          socket and ends the loop.

        Client sends are zero-copy and tracked; send buffers are recycled
        once zmq reports them done.
        """
        encoder = MsgpackEncoder()
        # Pool of previously used send buffers available for reuse.
        free_buffers: list[bytearray] = []
        # In-flight zero-copy sends, newest on the left. Each entry keeps its
        # outputs and buffer alive until zmq is finished with them (outputs
        # may contain tensors/np arrays whose backing buffers were extracted
        # for zero-copy send).
        in_flight: deque[tuple[zmq.MessageTracker, Any, bytearray]] = deque()

        with ExitStack() as stack, zmq.Context() as ctx:
            # linger is set so the ENGINE_CORE_DEAD message is sent before the
            # sockets are closed.
            client_sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, path, zmq.PUSH, linger=4000)
                )
                for path in output_paths
            ]
            coord_socket = None
            if coord_output_path is not None:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
            buffer_pool_cap = len(client_sockets) + 1

            while True:
                item = self.output_queue.get()
                if item == EngineCoreProc.ENGINE_CORE_DEAD:
                    # Forward the death notice to every client, then stop.
                    for sock in client_sockets:
                        sock.send(item)
                    break

                assert not isinstance(item, bytes)
                client_index, outputs = item
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Coordinator messages are very small; encode them without
                    # pooled buffers.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Return buffers of completed sends (oldest sit at the right
                # end) to the pool.
                while in_flight and in_flight[-1][0].done:
                    _, _, finished_buffer = in_flight.pop()
                    free_buffers.append(finished_buffer)

                buffer = free_buffers.pop() if free_buffers else bytearray()
                frames = encoder.encode_into(outputs, buffer)
                tracker = client_sockets[client_index].send_multipart(
                    frames, copy=False, track=True
                )
                if tracker.done:
                    # Send already completed; pool the buffer if there is room.
                    if len(free_buffers) < buffer_pool_cap:
                        free_buffers.append(buffer)
                else:
                    keep_alive = outputs if len(frames) > 1 else None
                    in_flight.appendleft((tracker, keep_alive, buffer))
1038
1039
1040
1041
1042
1043
1044
1045
1046


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        self.current_wave = 0
        # Last value of scheduler.get_request_counts() that was published.
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Record this engine's DP rank and create the stateless DP group.

        When a KV-transfer config is present, ``_dp{local_dp_rank}`` is also
        appended to its ``engine_id``.
        """
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert local_dp_rank is not None
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug(
                "Setting kv_transfer_config.engine_id to %s",
                vllm_config.kv_transfer_config.engine_id,
            )

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def _destroy_dp_group(self) -> None:
        """Destroy ``self.dp_group`` unless that exact group is already gone.

        ``reinitialize_distributed`` tears the group down and then calls
        ``shutdown()``, which would otherwise tear the same group down a
        second time. For a rank being removed, its final ``shutdown()`` would
        do so a third time. Remembering the last destroyed group makes the
        repeat calls no-ops. ``self.dp_group`` itself is left untouched.
        """
        dp_group = getattr(self, "dp_group", None)
        if not dp_group or dp_group is getattr(self, "_destroyed_dp_group", None):
            return
        self._destroyed_dp_group = dp_group
        stateless_destroy_torch_distributed_process_group(dp_group)

    def shutdown(self):
        """Shut down the engine, then destroy the DP process group (once)."""
        super().shutdown()
        self._destroy_dp_group()

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request, tracking the DP wave it belongs to.

        With a coordinator:
        - A request from a newer wave advances ``current_wave``.
        - A request for an older wave that arrives while engines are idle
          makes this engine ask (via client index -1) for ``current_wave``
          to be started.
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle ``START_DP_WAVE`` here; delegate everything else.

        A ``(new_wave, exclude_eng_index)`` wave-start moves this engine to
        ``new_wave`` and marks engines as running. This is skipped when the
        request excludes this engine or refers to an older wave.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Send scheduler request counts to client -1 when they change."""
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank still has unfinished requests.

        Only every 32nd call does the collective; the others return True.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild this rank's DP setup from ``reconfig_request``.

        Steps:
        1. Tear down the current DP group and the engine.
        2. Apply the new DP size, rank and master address/port.
        3. Recreate the DP group unless the new rank is -2.
        4. Reinitialize the executor.
        5. When the DP size grew, sync the KV-cache memory budget and warm up
           the model.
        """
        self._destroy_dp_group()
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        # NOTE(review): -1 / -2 look like the raw values of the
        # ReconfigureRankType KEEP_CURRENT_RANK / SHUTDOWN_CURRENT_RANK
        # sentinels used below -- confirm and switch to the enum.
        if reconfig_request.new_data_parallel_rank != -1:
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        if reconfig_request.new_data_parallel_rank != -2:
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )
1247

Rui Qiao's avatar
Rui Qiao committed
1248
1249
1250
1251
1252
1253
1254
1255
1256

class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Stored before super().__init__(): _perform_handshakes() yields these
        # addresses and is presumably invoked during base-class init -- confirm.
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

        super().__init__(vllm_config, local_client, "", executor_class, log_stats)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict visible devices to this actor's local DP slice.

        Does nothing on XPU. On other platforms it rewrites the platform's
        device-control environment variable.
        """
        from vllm.platforms import current_platform

        if current_platform.is_xpu():
            return
        device_control_env_var = current_platform.device_control_env_var
        self._set_cuda_visible_devices(
            vllm_config, local_dp_rank, device_control_env_var
        )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Set ``device_control_env_var`` to the devices for ``local_dp_rank``.

        The selected range is ``[local_dp_rank * world_size,
        (local_dp_rank + 1) * world_size)``, as reported in the error message.

        Raises:
            RuntimeError: If ``get_device_indices`` raises ``IndexError``
                (a ``RuntimeError`` is still caught by ``except Exception``
                handlers that handled the previous generic ``Exception``).
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e
        os.environ[device_control_env_var] = value

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        ``SystemExit`` is logged at debug level and other exceptions are
        logged as fatal; both are re-raised. The engine is always shut down
        afterwards.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()