# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
4
import os
5
import queue
6
import signal
7
8
import threading
import time
9
from collections import deque
10
from collections.abc import Callable, Generator
11
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
12
from contextlib import ExitStack, contextmanager
13
from inspect import isclass, signature
14
from logging import DEBUG
15
from typing import Any, TypeVar
16

17
import msgspec
18
19
import zmq

20
21
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
22
from vllm.distributed.parallel_state import is_global_first_rank
23
from vllm.envs import enable_envs_cache
24
from vllm.logger import init_logger
25
from vllm.logging_utils.dump_input import dump_engine_exception
26
from vllm.lora.request import LoRARequest
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.cache import engine_receiver_cache_from_config
29
from vllm.tasks import POOLING_TASKS, SupportedTask
30
31
32
33
34
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils import (
    decorate_logs,
    set_process_title,
)
35
from vllm.utils.gc_utils import maybe_attach_gc_debug_callback
36
from vllm.utils.hashing import get_hash_fn_by_name
37
from vllm.utils.import_utils import resolve_obj_by_qualname
38
from vllm.utils.network_utils import make_zmq_socket
39
40
41
42
43
44
45
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
46
from vllm.v1.core.sched.interface import SchedulerInterface
47
48
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from vllm.v1.engine import (
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
63
from vllm.v1.executor import Executor
64
from vllm.v1.kv_cache_interface import KVCacheConfig
65
from vllm.v1.metrics.stats import SchedulerStats
66
from vllm.v1.outputs import ModelRunnerOutput
67
from vllm.v1.request import Request, RequestStatus
68
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
69
from vllm.v1.structured_output import StructuredOutputManager
70
71
72
73
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Polling timeout in seconds (used outside this view; presumably by the
# busy loop when waiting on its queues -- confirm against callers).
POLLING_TIMEOUT_S = 2.5

# How long startup_handshake waits for the front-end's init message.
HANDSHAKE_TIMEOUT_MINS = 5

# Return type for collective_rpc.
_R = TypeVar("_R")


class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
    ):
        """Build the model executor, KV caches, scheduler and helper state.

        Args:
            vllm_config: Engine configuration. ``cache_config.num_gpu_blocks``
                / ``num_cpu_blocks`` are written back after KV-cache sizing,
                and ``scheduler_config.chunked_prefill_enabled`` may be
                cleared for models without KV cache.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Passed to the scheduler to enable stats collection.
            executor_fail_callback: If given, registered on the executor via
                ``register_failure_callback``.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        if is_global_first_rank():
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        # Overwritten by _initialize_kv_caches when the model has a KV cache;
        # stays -1 for attention-free models.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. The class may be given as a qualified name string.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls
            )
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls,
            )

        if len(kv_cache_config.kv_cache_groups) == 0:
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.chunked_prefill_enabled = False

        # With decode context parallelism each scheduler block spans
        # block_size tokens on every DCP rank.
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None
        if self.scheduler.connector is not None:  # type: ignore
            self.model_executor.init_kv_output_aggregator(
                self.scheduler.connector.get_finished_count()  # type: ignore
            )

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = engine_receiver_cache_from_config(
            vllm_config, mm_registry
        )

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        # Block hashes are needed for prefix caching and by KV connectors.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if (
            self.vllm_config.cache_config.enable_prefix_caching
            or self.scheduler.get_kv_connector() is not None
        ):
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )

    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Size and allocate the KV caches, then warm up the executor.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``;
            ``num_cpu_blocks`` is always 0 here.
        """
        # perf_counter is monotonic, so the elapsed time below cannot be
        # skewed by wall-clock adjustments during (long) profiling.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        if any(kv_cache_specs):
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Scale-up launch: reuse the KV-cache memory size synced over
                # the DP group instead of profiling again.
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )
        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the model executor."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case (unused by this base implementation).

        Raises:
            TypeError: if ``request.request_id`` is not a ``str``.
            ValueError: if the request's pooling task is not supported.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r}. "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        if (
            request.kv_transfer_params is not None
            and not self.scheduler.get_kv_connector()
        ):
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Execute the model and log detailed info on failure."""
        try:
            yield
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare raise re-raises the original exception and traceback.
            raise

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()

        with self.log_error_detail(scheduler_output):
            model_output = self.model_executor.execute_model(scheduler_output)

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

    def post_step(self, model_executed: bool) -> None:
        """Feed draft token ids back to the scheduler after a step.

        Only applies when speculative decoding is enabled and the model
        actually ran in the step.
        """
        if self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, filling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            future = self.model_executor.execute_model(scheduler_output, non_block=True)
            # Newest batches go on the left; the oldest is popped from the right.
            batch_queue.appendleft((future, scheduler_output))

            model_executed = scheduler_output.total_num_scheduled_tokens > 0
            if (
                model_executed
                and len(batch_queue) < self.batch_queue_size
                and not batch_queue[-1][0].done()
            ):
                # Don't block on next worker response unless the queue is full
                # or there are no more requests to schedule.
                return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output = batch_queue.pop()
        with self.log_error_detail(scheduler_output):
            model_output = future.result()

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )
        return engine_core_outputs, model_executed

    def shutdown(self):
        """Clear the structured-output backend, then shut down executor and scheduler."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling on the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Clear the multi-modal caches held here and by the executor."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(self):
        """Delegate prefix-cache reset to the scheduler."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Delegate to the executor's ``sleep(level)``."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: list[str] | None = None):
        """Delegate to the executor's ``wake_up(tags)``."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` flag."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run a dummy batch on the executor."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate LoRA loading to the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate LoRA removal to the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate LoRA pinning to the executor."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Delegate sharded-state saving to the executor."""
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Forward a collective RPC to the executor and return its results."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The constructed ``Request`` and the request's ``current_wave``.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

493
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
494

495
496
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, build the engine and start IO threads.

        The whole engine construction happens inside the handshake context,
        so the READY message is only sent once the engine and its socket
        threads are up (and, with a DP coordinator, once the input thread
        signals readiness).
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()

        def executor_fail_callback():
            # Executor failures are surfaced to the busy loop as an input.
            return self.input_queue.put_nowait(
                (EngineCoreRequestType.EXECUTOR_FAILED, b"")
            )

        self.engine_index = engine_index
        # 2-byte little-endian ZMQ identity derived from the engine index.
        engine_identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(
            handshake_address,
            engine_identity,
            local_client,
            vllm_config,
            client_handshake_address,
        ) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            stats_address = addresses.frontend_stats_publish_address
            self.frontend_stats_publish_address = stats_address
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            # Request-queue stats go to the coordinator only in the "internal"
            # and "hybrid" LB modes (i.e. not with an external LB).
            external_lb = vllm_config.parallel_config.data_parallel_external_lb
            self.publish_dp_lb_stats = self.has_coordinator and not external_lb

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config, executor_class, log_stats, executor_fail_callback
            )

            # Background threads move data between ZMQ sockets and the
            # queues used by the busy loop, overlapping socket IO and
            # (de)serialization with the model forward pass.
            input_ready = threading.Event()
            input_args = (
                addresses.inputs,
                addresses.coordinator_input,
                engine_identity,
                input_ready,
            )
            input_thread = threading.Thread(
                target=self.process_input_sockets, args=input_args, daemon=True
            )
            input_thread.start()

            output_args = (
                addresses.outputs,
                addresses.coordinator_output,
                self.engine_index,
            )
            self.output_thread = threading.Thread(
                target=self.process_output_sockets, args=output_args, daemon=True
            )
            self.output_thread.start()

            # Hold the handshake open until the input thread reports ready
            # (with a DP coordinator this waits for its READY message).
            while not input_ready.wait(timeout=10):
                if not input_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

        # Treat the startup heap as static so later oldest-generation GC
        # passes skip it.
        gc.collect()
        gc.freeze()

        # If enabled, attach the GC debugger after the freeze.
        maybe_attach_gc_debug_callback()

        # No further environment-variable overrides are expected from here.
        enable_envs_cache()

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.

        Yields the addresses received from the front-end. Once the caller's
        ``with`` body completes, the READY message(s) are sent and
        ``vllm_config.__post_init__()`` is re-run.
        """
        is_local = local_client and client_handshake_address is None
        headless = not local_client
        # The context only backs the short-lived handshake sockets below
        # (each closed by its own ``with``). Terminate it deterministically
        # instead of leaving it to garbage collection, which would destroy it
        # late and emit a ResourceWarning for an unclosed context.
        with zmq.Context() as input_ctx:
            handshake = self._perform_handshake(
                input_ctx,
                handshake_address,
                identity,
                is_local,
                headless,
                vllm_config,
                vllm_config.parallel_config,
            )
            if client_handshake_address is None:
                with handshake as addresses:
                    yield addresses
            else:
                assert local_client
                local_handshake = self._perform_handshake(
                    input_ctx,
                    client_handshake_address,
                    identity,
                    True,
                    False,
                    vllm_config,
                )
                # Coordinator/DP info comes from the primary handshake; the
                # client IO socket addresses come from the colocated one.
                with handshake as addresses, local_handshake as client_addresses:
                    addresses.inputs = client_addresses.inputs
                    addresses.outputs = client_addresses.outputs
                    yield addresses

            # Update config which may have changed from the handshake
            vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run one HELLO -> init -> READY exchange with a front-end process.

        Connects a DEALER socket to ``handshake_address``, registers via
        ``startup_handshake`` and yields the received addresses. When the
        caller's ``with`` body finishes without error, a READY message is
        sent carrying ``num_gpu_blocks``, the stats publish address and, for
        DP > 1, the parallel config hash.
        """
        with make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        ) as sock:
            # Register engine with front-end.
            yield self.startup_handshake(
                sock, local_client, headless, parallel_config_to_update
            )

            ready_msg = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
                # Coordinator stats update address, passed back for the
                # external-LB case so the colocated front-end can use it
                # (the coordinator only runs with rank 0).
                "dp_stats_address": self.frontend_stats_publish_address,
            }
            parallel_config = vllm_config.parallel_config
            if parallel_config.data_parallel_size > 1:
                # Lets the front-end validate that DP configs agree.
                ready_msg["parallel_config_hash"] = parallel_config.compute_hash()

            sock.send(msgspec.msgpack.encode(ready_msg))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Send HELLO, wait for the front-end's init message and apply it.

        Waits up to ``HANDSHAKE_TIMEOUT_MINS`` minutes for a reply. If
        ``parallel_config`` is given, every key in the received
        ``parallel_config`` mapping is written onto it via ``setattr``.

        Raises:
            RuntimeError: if no reply arrives within the timeout.
        """
        hello = {"status": "HELLO", "local": local_client, "headless": headless}
        handshake_socket.send(msgspec.msgpack.encode(hello))

        logger.info("Waiting for init message from front-end.")
        timeout_ms = HANDSHAKE_TIMEOUT_MINS * 60_000
        if not handshake_socket.poll(timeout=timeout_ms):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                "minutes"
            )

        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            handshake_socket.recv(), type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for attr_name, attr_value in init_message.parallel_config.items():
                setattr(parallel_config, attr_name, attr_value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers, then builds either a
        ``DPEngineCoreProc`` (when ``data_parallel_size > 1`` or
        ``dp_rank > 0``) or a plain ``EngineCoreProc`` from the forwarded
        arguments and runs its busy loop. The engine is always shut down on
        exit.

        Args:
            *args: Positional arguments forwarded to the engine constructor.
            dp_rank: Data-parallel rank assigned to this engine process.
            local_dp_rank: Node-local data-parallel rank of this process.
            **kwargs: Keyword arguments forwarded to the engine constructor;
                must contain ``vllm_config``.

        Raises:
            SystemExit: On SIGTERM/SIGINT (raised once by the handler below).
            Exception: Any startup or runtime error is logged and re-raised.
        """
        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: EngineCoreProc | None = None
        try:
            parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                set_process_title("EngineCore", f"DP{dp_rank}")
                decorate_logs()
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                set_process_title("EngineCore")
                decorate_logs()
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                # Notify the client that this engine is dead before teardown.
                engine_core._send_engine_dead()
            # Bare raise re-raises the handled exception with its original
            # traceback (``raise e`` would add a re-raise frame).
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

793
794
795
    def _init_data_parallel(self, vllm_config: VllmConfig):
        pass

796
797
798
    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Alternates forever between handling client input and stepping the
        engine; it only ends when one of those calls raises (for example the
        SystemExit raised on SIGINT/SIGTERM).
        """
        handle_inputs = self._process_input_queue
        step_engine = self._process_engine_step
        while True:
            # Block until there is work, then drain queued client requests.
            handle_inputs()
            # Run one engine step and enqueue its outputs.
            step_engine()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed."""

        waited = False
810
811
812
813
814
        while (
            not self.engines_running
            and not self.scheduler.has_requests()
            and not self.batch_queue
        ):
815
816
817
818
819
820
821
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
822
            logger.debug("EngineCore loop active.")
823
824
825
826
827
828

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

829
    def _process_engine_step(self) -> bool:
830
831
832
        """Called only when there are unfinished local requests."""

        # Step the engine core.
833
        outputs, model_executed = self.step_fn()
834
        # Put EngineCoreOutputs into the output queue.
835
        for output in outputs.items() if outputs else ():
836
            self.output_queue.put_nowait(output)
837
838
        # Post-step hook.
        self.post_step(model_executed)
839

840
841
        return model_executed

842
843
844
    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch request from client.

        Args:
            request_type: Kind of request read from an input socket.
            request: Decoded payload; its shape depends on ``request_type``:
                ``(request, request_wave)`` for ADD, passed straight to
                ``abort_requests`` for ABORT, and
                ``(client_idx, call_id, method_name, args)`` for UTILITY.

        Raises:
            RuntimeError: On an EXECUTOR_FAILED notification.
            SystemExit, KeyboardInterrupt: Propagated from a utility call.
        """
        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except (SystemExit, KeyboardInterrupt):
                # Never turn shutdown into a utility "failure": when running
                # under run_engine_core, the signal handler raises SystemExit
                # only once, so swallowing it here would leave the engine
                # unable to terminate.
                raise
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (
                    f"Call to {method_name} method failed: {e}"
                )
            # Reply to the calling client with the result or failure.
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output))
            )
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )
873
874
875
876

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
877
        arg type, try converting to msgspec object."""
878
879
880
881
882
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
883
884
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
885
            and issubclass(p.annotation, msgspec.Struct)
886
887
888
889
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, arg_types)
        )
890

891
892
893
894
895
896
897
898
899
    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient.

        Enqueues the ENGINE_CORE_DEAD sentinel for the output thread. It then
        waits up to five seconds for ``output_thread`` to finish, so the
        message goes out before shutdown.
        """
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        sender = self.output_thread
        sender.join(timeout=5.0)
        if not sender.is_alive():
            return
        logger.fatal(
            "vLLM shutdown signal from EngineCore failed "
            "to send. Please report this issue."
        )
904

905
906
907
    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Connects a DEALER socket to each front-end input address and, when
        ``coord_input_address`` is given, an XSUB socket to the coordinator.
        Sets ``ready_event`` once all sockets are ready. It then decodes
        incoming messages forever and pushes ``(request_type, request)``
        tuples onto ``self.input_queue``.

        Args:
            input_addresses: Front-end addresses to connect to.
            coord_input_address: Coordinator address, or ``None`` when there
                is no coordinator.
            identity: ZMQ identity used for every input socket.
            ready_event: Set once sockets are connected and registered.
        """
        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. The recv() must
                # not live inside the assert: ``python -O`` strips asserts, so
                # the READY frame would never be consumed here. The poll loop
                # below would then pick it up as if it were a request.
                coord_ready_msg = coord_socket.recv()
                assert coord_ready_msg == b"READY", (
                    f"Unexpected message from coordinator: {coord_ready_msg!r}"
                )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    if request_type == EngineCoreRequestType.ADD:
                        request = add_request_decoder.decode(data_frames)
                        request = self.preprocess_add_request(request)
                    else:
                        request = generic_decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

974
975
976
    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread.

        Takes items from ``self.output_queue``. ``(client_index, outputs)``
        pairs are stamped with ``engine_index`` and sent to the PUSH socket
        for ``output_paths[client_index]``, or to the coordinator socket when
        ``client_index == -1``. The ENGINE_CORE_DEAD sentinel is sent to
        every client socket and ends the thread.

        Client messages are encoded into reusable bytearrays. Each buffer,
        and for multi-frame messages the ``outputs`` object, is kept alive
        until zmq's tracker reports the zero-copy send as done.
        """
        encoder = MsgpackEncoder()
        # Buffers whose sends have completed and can be encoded into again.
        free_buffers: list[bytearray] = []
        # In-flight sends (newest on the left) as (tracker, keep-alive ref,
        # buffer). Outputs may contain tensors/np arrays whose backing
        # buffers were extracted for zero-copy send, so they must outlive it.
        in_flight = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # linger=4000 ensures the ENGINE_CORE_DEAD message is sent prior to
        # closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            client_sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, path, zmq.PUSH, linger=4000)
                )
                for path in output_paths
            ]
            coord_socket = None
            if coord_output_path is not None:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
            buffer_cap = len(client_sockets) + 1

            while True:
                item = self.output_queue.get()
                if item == EngineCoreProc.ENGINE_CORE_DEAD:
                    for sock in client_sockets:
                        sock.send(item)
                    break

                assert not isinstance(item, bytes)
                client_index, outputs = item
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Coordinator messages are very small: skip buffer reuse.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Recycle buffers of the oldest sends that zmq has finished.
                while in_flight and in_flight[-1][0].done:
                    _, _, finished_buffer = in_flight.pop()
                    free_buffers.append(finished_buffer)

                buffer = free_buffers.pop() if free_buffers else bytearray()
                frames = encoder.encode_into(outputs, buffer)
                tracker = client_sockets[client_index].send_multipart(
                    frames, copy=False, track=True
                )

                if not tracker.done:
                    keep_alive = outputs if len(frames) > 1 else None
                    in_flight.appendleft((tracker, keep_alive, buffer))
                elif len(free_buffers) < buffer_cap:
                    # Limit the number of buffers to reuse.
                    free_buffers.append(buffer)
1043
1044
1045
1046
1047
1048
1049
1050
1051


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        """Set up DP wave bookkeeping, then initialize the engine core.

        The DP rank read from ``vllm_config.parallel_config`` is forwarded to
        ``EngineCoreProc.__init__`` as its last positional argument.
        """
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        # Index of the DP wave this engine is currently in.
        self.current_wave = 0
        # Request counts most recently published to the coordinator.
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Validate DP ranks and create the stateless DP process group.

        If KV transfer is configured, ``_dp{local_dp_rank}`` is appended to
        its ``engine_id`` so each DP rank gets a distinct id.
        """
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug(
                "Setting kv_transfer_config.engine_id to %s",
                vllm_config.kv_transfer_config.engine_id,
            )

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine core, then destroy the DP process group.

        Safe to call repeatedly: the group reference is cleared once it has
        been destroyed.
        """
        super().shutdown()
        self._destroy_dp_group()

    def _destroy_dp_group(self):
        """Destroy the stateless DP process group (if any) and forget it."""
        dp_group = getattr(self, "dp_group", None)
        if dp_group:
            # Clear first so a later shutdown never destroys it a second time.
            self.dp_group = None
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request, reconciling its DP wave with the current one.

        With a coordinator, a newer ``request_wave`` is adopted as the
        current wave. An older wave that arrives while engines are idle
        makes this engine ask the coordinator (client index -1) to start
        ``current_wave``.
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle START_DP_WAVE here; defer every other type to the base.

        START_DP_WAVE carries ``(new_wave, exclude_eng_index)``. Unless this
        engine is the excluded one or the wave is stale, the wave is adopted
        and an idle engine loop is started.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Send scheduler request counts to the coordinator when they change.

        No-op unless ``publish_dp_lb_stats`` is set.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank may still have unfinished requests.

        Only every 32nd call actually synchronizes across the DP group. On
        all other calls, True is returned without communication.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild the DP setup for an elastic scale-up or scale-down.

        Tears down the current DP group and shuts the engine down, then
        applies the new DP size, rank and master address/port. A new DP
        group is created unless this rank is being shut down. Finally the
        model executor is reinitialized.
        """
        # Destroy the current group and clear the reference, so the
        # shutdown() below does not destroy the same group a second time.
        self._destroy_dp_group()
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )
1251

Rui Qiao's avatar
Rui Qiao committed
1252
1253
1254
1255
1256
1257
1258
1259
1260

class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        """Record addresses and DP ranks, pin devices, then init the engine.

        Args:
            vllm_config: Engine config; its parallel config receives
                ``dp_rank`` and ``local_dp_rank`` in place.
            local_client: Forwarded to ``DPEngineCoreProc``.
            addresses: Pre-computed ZMQ addresses, later yielded by
                ``_perform_handshakes`` in place of a real handshake.
            executor_class: Executor class forwarded to the engine.
            log_stats: Forwarded to the engine.
            dp_rank: Data-parallel rank of this actor.
            local_dp_rank: Node-local data-parallel rank of this actor.
        """
        self.addresses = addresses
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_rank = dp_rank
        parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

        # No handshake happens under Ray, so the handshake address is empty.
        super().__init__(
            vllm_config,
            local_client,
            "",
            executor_class,
            log_stats,
        )
Rui Qiao's avatar
Rui Qiao committed
1292

1293
    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict device visibility for this actor's DP rank.

        Nothing is done on XPU. On every other platform, the platform's
        device-control environment variable is set through
        ``_set_cuda_visible_devices``.
        """
        from vllm.platforms import current_platform

        if not current_platform.is_xpu():
            self._set_cuda_visible_devices(
                vllm_config,
                local_dp_rank,
                current_platform.device_control_env_var,
            )
1303

1304
1305
1306
    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Set ``device_control_env_var`` to this DP rank's device indices.

        The indices come from ``get_device_indices``. Per the error message
        below, the expected local range is
        ``[local_dp_rank * world_size, (local_dp_rank + 1) * world_size)``.

        Raises:
            RuntimeError: If ``get_device_indices`` raises ``IndexError``
                (presumably the range exceeds the devices listed in the
                variable's current value).
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            # RuntimeError (an Exception subclass) instead of a bare
            # Exception, so existing ``except Exception`` callers still work.
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e
        os.environ[device_control_env_var] = value
1321

Rui Qiao's avatar
Rui Qiao committed
1322
    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """Yield the pre-known ZMQ addresses without any handshake.

        Under Ray, all address information is known before the actor is
        created. The arguments are therefore ignored and the addresses
        stored in ``__init__`` are yielded directly.
        """
        known_addresses = self.addresses
        yield known_addresses

    def wait_for_init(self):
        """Return once actor construction has completed.

        Deliberately empty. When ``ray.get()`` returns for this (or any
        other) actor method, actor creation, i.e. ``__init__``, is
        guaranteed to be complete. Callers use it purely as a barrier.
        """
        return None

    def run(self):
        """Run the engine core busy loop until it exits.

        SystemExit is logged at debug level and any other ``Exception`` is
        logged as fatal. Every exception is re-raised unchanged, and the
        engine is shut down on every exit path.
        """
        try:
            self.run_busy_loop()
        except BaseException as exc:
            if isinstance(exc, SystemExit):
                logger.debug("EngineCore exiting.")
            elif isinstance(exc, Exception):
                logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()