core.py 52.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import gc
4
import os
5
import queue
6
import signal
7
8
import threading
import time
9
from collections import deque
Rui Qiao's avatar
Rui Qiao committed
10
from collections.abc import Generator
11
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
12
from contextlib import ExitStack, contextmanager
13
from inspect import isclass, signature
14
from logging import DEBUG
15
from typing import Any, Callable, Optional, TypeVar, Union
16

17
import msgspec
18
19
import zmq

20
21
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
22
from vllm.logger import init_logger
23
from vllm.logging_utils.dump_input import dump_engine_exception
24
from vllm.lora.request import LoRARequest
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.multimodal.cache import engine_receiver_cache_from_config
27
from vllm.tasks import POOLING_TASKS, SupportedTask
28
29
30
31
32
33
34
35
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils import (
    decorate_logs,
    get_hash_fn_by_name,
    make_zmq_socket,
    resolve_obj_by_qualname,
    set_process_title,
)
36
from vllm.utils.gc_utils import maybe_attach_gc_debug_callback
37
38
39
40
41
42
43
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
44
from vllm.v1.core.sched.interface import SchedulerInterface
45
46
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from vllm.v1.engine import (
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
61
from vllm.v1.executor.abstract import Executor
62
from vllm.v1.kv_cache_interface import KVCacheConfig
63
from vllm.v1.metrics.stats import SchedulerStats
64
from vllm.v1.outputs import ModelRunnerOutput
65
from vllm.v1.request import Request, RequestStatus
66
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
67
from vllm.v1.structured_output import StructuredOutputManager
68
69
70
71
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Polling timeout in seconds. NOTE(review): not referenced in this part of the
# file; presumably used by the busy loop / socket threads — confirm.
POLLING_TIMEOUT_S = 2.5
# Maximum time (in minutes) startup_handshake waits for the front-end's init
# message before raising.
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar("_R")  # Return type for collective_rpc

class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Optional[Callable] = None,
    ):
        """Build the model executor, KV caches and scheduler.

        Args:
            vllm_config: Engine configuration. Some fields are updated in
                place (``cache_config.num_gpu_blocks``/``num_cpu_blocks`` and,
                for models without KV cache, ``chunked_prefill_enabled``).
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Forwarded to the scheduler.
            executor_fail_callback: Optional callback registered with the
                executor via ``register_failure_callback``.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        logger.info(
            "Initializing a V1 LLM engine (v%s) with config: %s",
            VLLM_VERSION,
            vllm_config,
        )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls
            )
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls,
            )

        if not kv_cache_config.kv_cache_groups:
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.chunked_prefill_enabled = False

        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None

        # Go through the SchedulerInterface accessor (as add_request does)
        # rather than poking at a concrete `.connector` attribute, so custom
        # scheduler implementations of the interface keep working.
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            self.model_executor.init_kv_output_aggregator(
                kv_connector.get_finished_count()
            )

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = engine_receiver_cache_from_config(
            vllm_config, mm_registry
        )

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]]
        ] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        # Block hashes are only needed for prefix caching or KV transfer.
        self.request_block_hasher: Optional[Callable[[Request], list[BlockHash]]] = None
        if (
            self.vllm_config.cache_config.enable_prefix_caching
            or kv_connector is not None
        ):
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )

    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Profile memory, size the KV caches and initialize them on workers.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``;
            ``num_cpu_blocks`` is always 0 here.
        """
        # Monotonic clock: wall-clock adjustments must not skew the timing.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Scale-up launch: take the KV cache memory size synced across
                # the DP group instead of profiling locally.
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )
        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the model executor."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case

        Raises:
            TypeError: if ``request.request_id`` is not a ``str``.
            ValueError: if a pooling request asks for an unsupported task.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        if (
            request.kv_transfer_params is not None
            and self.scheduler.get_kv_connector() is None
        ):
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    def execute_model_with_error_logging(
        self,
        model_fn: Callable[[SchedulerOutput], ModelRunnerOutput],
        scheduler_output: SchedulerOutput,
    ) -> ModelRunnerOutput:
        """Execute the model and log detailed info on failure.

        Returns ``model_fn(scheduler_output)``. If that raises an
        ``Exception``, engine state is dumped via ``dump_engine_exception`` and
        the original exception is re-raised.
        """
        try:
            return model_fn(scheduler_output)
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: dump_engine_exception is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare `raise` re-raises the in-flight exception unchanged.
            raise

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model_with_error_logging(
            self.model_executor.execute_model,  # type: ignore
            scheduler_output,
        )
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0)

    def post_step(self, model_executed: bool) -> None:
        """After a step, hand new draft token ids (spec decode) to the scheduler."""
        if self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. If the scheduler has requests, schedule a new batch and submit it
        to the executor without blocking. If that batch has scheduled tokens,
        the queue still has room, and the oldest in-flight batch is not done
        yet, return ``(None, True)`` immediately: filling the batch queue has
        a higher priority than collecting model outputs.
        2. Otherwise (queue full, nothing new to schedule, or the oldest batch
        already finished), block until the oldest batch in the queue is done.
        3. Update the scheduler from that output.

        The returned flag reports whether a batch with scheduled tokens was
        submitted during this call.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            future = self.model_executor.execute_model(scheduler_output, non_block=True)
            # Newest batches go on the left; the oldest is popped from the right.
            batch_queue.appendleft((future, scheduler_output))  # type: ignore[arg-type]

            model_executed = scheduler_output.total_num_scheduled_tokens > 0
            if (
                model_executed
                and len(batch_queue) < self.batch_queue_size
                and not batch_queue[-1][0].done()
            ):
                # Don't block on next worker response unless the queue is full
                # or there are no more requests to schedule.
                return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output = batch_queue.pop()
        model_output = self.execute_model_with_error_logging(
            lambda _: future.result(), scheduler_output
        )

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, model_executed

    def shutdown(self):
        """Release the structured-output backend, executor and scheduler."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling in the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Clear the engine-side and executor-side multi-modal caches."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(self):
        """Delegate prefix-cache reset to the scheduler."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the executor to sleep at the given level."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Wake the executor up, optionally only for the given tags."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` state."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run a dummy batch on the executor."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate LoRA loading to the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate LoRA removal to the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate LoRA pinning to the executor."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate sharded-state saving to the executor."""
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: Union[str, Callable[..., _R]],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> list[_R]:
        """Forward a collective RPC to the executor and return its results."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def save_tensorized_model(
        self,
        tensorizer_config,
    ) -> None:
        """Delegate tensorizer model saving to the executor."""
        self.model_executor.save_tensorized_model(
            tensorizer_config=tensorizer_config,
        )

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The constructed ``Request`` and ``request.current_wave``.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

503
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
504

505
506
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: Optional[str] = None,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, build the engine and start IO threads.

        The ZMQ identity is ``engine_index`` as 2 little-endian bytes. Inside
        the handshake context this sets up DP state, runs
        ``EngineCore.__init__``, and starts the input/output socket threads. It
        then waits for the input thread's ready event before letting the
        handshake complete.

        Raises:
            RuntimeError: if the input socket thread dies during startup.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], bytes]]()

        def executor_fail_callback() -> None:
            # Enqueue an EXECUTOR_FAILED request on the input queue.
            # NOTE(review): presumably consumed by the busy loop (not visible
            # here) to react to executor failure — confirm.
            self.input_queue.put_nowait((EngineCoreRequestType.EXECUTOR_FAILED, b""))

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        ) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            # Only publish request queue stats to coordinator for "internal"
            # and "hybrid" LB modes.
            self.publish_dp_lb_stats = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config, executor_class, log_stats, executor_fail_callback
            )

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    identity,
                    ready_event,
                ),
                daemon=True,
            )
            input_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(
                    addresses.outputs,
                    addresses.coordinator_output,
                    self.engine_index,
                ),
                daemon=True,
            )
            self.output_thread.start()

            # Don't complete handshake until DP coordinator ready message is
            # received.
            while not ready_event.wait(timeout=10):
                if not input_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        gc.collect()
        gc.freeze()

        # If enable, attach GC debugger after static variable freeze.
        maybe_attach_gc_debug_callback()

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: Optional[str],
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.

        The ZMQ context created here is destroyed on exit (normal or error),
        so its IO thread does not outlive the handshake.
        """
        input_ctx = zmq.Context()
        try:
            is_local = local_client and client_handshake_address is None
            headless = not local_client
            handshake = self._perform_handshake(
                input_ctx,
                handshake_address,
                identity,
                is_local,
                headless,
                vllm_config,
                vllm_config.parallel_config,
            )

            if client_handshake_address is None:
                with handshake as addresses:
                    yield addresses
            else:
                assert local_client
                local_handshake = self._perform_handshake(
                    input_ctx, client_handshake_address, identity, True, False, vllm_config
                )
                # Coordinator/DP info comes from the first handshake; client
                # input/output socket addresses come from the colocated one.
                with handshake as addresses, local_handshake as client_addresses:
                    addresses.inputs = client_addresses.inputs
                    addresses.outputs = client_addresses.outputs
                    yield addresses

            # Update config which may have changed from the handshake
            vllm_config.__post_init__()
        finally:
            # Handshake sockets were already closed by their `with` blocks and
            # keep their own linger, so a pending READY message can still be
            # flushed during term. linger=0 only affects sockets that were
            # left open (e.g. on an error path), so destroy cannot hang.
            input_ctx.destroy(linger=0)

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: Optional[ParallelConfig] = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run one HELLO/init/READY handshake over a DEALER socket.

        Connects to ``handshake_address`` with ``identity`` and yields the
        addresses returned by ``startup_handshake``. The READY message (carrying
        ``num_gpu_blocks``, the DP stats address and, when DP > 1, the
        parallel-config hash) is sent only if the caller's ``with`` body
        completes without raising.
        """
        with make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        ) as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(
                handshake_socket, local_client, headless, parallel_config_to_update
            )
            yield addresses

            # Send ready message.
            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
            # We pass back the coordinator stats update address here for the
            # external LB case for our colocated front-end to use (coordinator
            # only runs with rank 0).
            dp_stats_address = self.frontend_stats_publish_address

            # Include config hash for DP configuration validation
            ready_msg = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": num_gpu_blocks,
                "dp_stats_address": dp_stats_address,
            }
            if vllm_config.parallel_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = (
                    vllm_config.parallel_config.compute_hash()
                )

            handshake_socket.send(msgspec.msgpack.encode(ready_msg))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: Optional[ParallelConfig] = None,
    ) -> EngineZmqAddresses:
        """Send HELLO, wait for the front-end's init message, return addresses.

        If ``parallel_config`` is given, every key/value in the init message's
        ``parallel_config`` mapping is applied to it via ``setattr``.

        Raises:
            RuntimeError: if no init message arrives within
                ``HANDSHAKE_TIMEOUT_MINS`` minutes.
        """
        # Send registration message.
        handshake_socket.send(
            msgspec.msgpack.encode(
                {
                    "status": "HELLO",
                    "local": local_client,
                    "headless": headless,
                }
            )
        )

        # Receive initialization message.
        logger.info("Waiting for init message from front-end.")
        # poll() takes milliseconds.
        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                "minutes"
            )
        init_bytes = handshake_socket.recv()
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            init_bytes, type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for key, value in init_message.parallel_config.items():
                setattr(parallel_config, key, value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers, constructs either a
        ``DPEngineCoreProc`` (when ``data_parallel_size > 1`` or
        ``dp_rank > 0``) or a plain ``EngineCoreProc``, and runs its busy
        loop. The engine core, if constructed, is always shut down on exit.

        Args:
            *args: Positional arguments forwarded to the engine constructor.
            dp_rank: Global data-parallel rank of this engine process.
            local_dp_rank: Data-parallel rank local to this node.
            **kwargs: Keyword arguments forwarded to the engine constructor;
                must contain ``vllm_config``.

        Raises:
            SystemExit: On the first SIGTERM/SIGINT.
            Exception: Any error raised during construction or the busy
                loop is logged and re-raised with its original traceback.
        """
        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: Optional[EngineCoreProc] = None
        try:
            parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                set_process_title("EngineCore", f"DP{dp_rank}")
                decorate_logs()
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                set_process_title("EngineCore")
                decorate_logs()
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare ``raise`` re-raises the active exception with its original
            # traceback (``raise e`` would append this frame to it).
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Data-parallel setup hook; does nothing for the non-DP engine.

        ``DPEngineCoreProc`` overrides this to configure its DP group.
        """

    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Alternates between handling client input and performing one engine
        step. It never returns normally; it ends only when an exception
        (e.g. the ``SystemExit`` raised on SIGINT/SIGTERM) unwinds it.
        """
        while True:
            # Wait for work while idle, then drain any queued requests.
            self._process_input_queue()
            # Perform one engine step and enqueue its outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Handle queued client requests; return once a step is needed.

        While engines are not running, the scheduler has no requests and
        the batch queue is empty, this blocks on ``input_queue`` and
        dispatches each request as it arrives. Once there is work, whatever
        is still queued is drained without blocking before returning.
        """
        logged_wait = False
        while not (
            self.engines_running
            or self.scheduler.has_requests()
            or self.batch_queue
        ):
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                logged_wait = True
            self._handle_client_request(*self.input_queue.get())

        if logged_wait:
            logger.debug("EngineCore loop active.")

        # Drain everything else that is already queued, without blocking.
        while not self.input_queue.empty():
            self._handle_client_request(*self.input_queue.get_nowait())

    def _process_engine_step(self) -> bool:
        """Run one engine step and hand its outputs to the output thread.

        Called only when there are unfinished local requests.

        Returns:
            The ``model_executed`` flag reported by ``step_fn``; it is also
            passed to ``post_step``.
        """
        outputs, model_executed = self.step_fn()

        # Each item is a (client_index, EngineCoreOutputs) pair, as unpacked
        # by process_output_sockets.
        if outputs:
            for item in outputs.items():
                self.output_queue.put_nowait(item)

        self.post_step(model_executed)
        return model_executed

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch a request received from a client.

        Args:
            request_type: Kind of request; selects the handler.
            request: Decoded payload whose shape depends on ``request_type``:
                ``(req, request_wave)`` for ADD, the payload passed straight
                to ``abort_requests`` for ABORT, and
                ``(client_idx, call_id, method_name, args)`` for UTILITY.

        Raises:
            RuntimeError: For an EXECUTOR_FAILED request.
        """
        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except Exception as e:
                # Only ordinary errors are reported back to the caller.
                # SystemExit / KeyboardInterrupt must propagate: the
                # SIGTERM/SIGINT handler in run_engine_core raises SystemExit
                # only once, so swallowing it here would cancel shutdown.
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = f"Call to {method_name} method failed: {e}"
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output))
            )
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )

    @staticmethod
    def _convert_msgspec_args(method, args):
        """Coerce decoded utility-call arguments to the method's types.

        Each positional argument whose matching parameter is annotated with
        a ``msgspec.Struct`` subclass, and which is not already an instance
        of it, is converted with ``msgspec.convert``. All other arguments
        pass through unchanged.

        Args:
            method: Callable that will receive the arguments.
            args: Positional arguments as decoded from the wire.

        Returns:
            ``args`` itself when empty, otherwise a tuple of the (possibly
            converted) arguments.
        """
        if not args:
            return args
        params = signature(method).parameters.values()
        assert len(args) <= len(params)

        def _coerce(value, param):
            target = param.annotation
            needs_convert = (
                isclass(target)
                and issubclass(target, msgspec.Struct)
                and not isinstance(value, target)
            )
            return msgspec.convert(value, type=target) if needs_convert else value

        return tuple(_coerce(v, p) for v, p in zip(args, params))

    def _send_engine_dead(self):
        """Tell the EngineCoreClient that this engine core is dead.

        Enqueues the ``ENGINE_CORE_DEAD`` sentinel and waits up to 5 seconds
        for the output thread to finish. Logs a fatal error if the thread is
        still alive after that.
        """
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # The output loop (process_output_sockets, presumably what
        # output_thread runs) exits after forwarding the sentinel, so a
        # finished thread means the message went out.
        self.output_thread.join(timeout=5.0)
        if not self.output_thread.is_alive():
            return
        logger.fatal(
            "vLLM shutdown signal from EngineCore failed "
            "to send. Please report this issue."
        )

    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: Optional[str],
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Connects a DEALER socket to each front-end input address and, if
        given, an XSUB socket to the DP coordinator. Sets ``ready_event``
        once every socket is set up. After that it runs forever, decoding
        incoming requests and pushing ``(request_type, request)`` tuples onto
        ``input_queue``.
        """
        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. The recv() must
                # not sit inside the assert: asserts are stripped under
                # ``python -O``. The READY frame would then never be consumed
                # and would reach the request-decoding loop below.
                coord_ready_msg = coord_socket.recv()
                assert coord_ready_msg == b"READY", (
                    f"Unexpected first message from coordinator: {coord_ready_msg!r}"
                )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event

            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    if request_type == EngineCoreRequestType.ADD:
                        request = add_request_decoder.decode(data_frames)
                        request = self.preprocess_add_request(request)
                    else:
                        request = generic_decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: Optional[str],
        engine_index: int,
    ):
        """Output socket IO thread.

        Takes items from ``output_queue`` and sends them:

        - The ``ENGINE_CORE_DEAD`` sentinel is sent on every client socket,
          and the loop then ends.
        - Messages for client index ``-1`` go to the coordinator socket.
        - All other messages are encoded into a (possibly reused) buffer and
          sent zero-copy to the indexed client socket.
        """
        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                )
                for output_path in output_paths
            ]
            coord_socket = (
                stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
                if coord_output_path is not None
                else None
            )
            # Upper bound on the number of idle buffers kept for reuse.
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with. The oldest sends
                # sit at the right end of ``pending``. Apply the same cap as
                # the append path below, so a burst of completions cannot
                # grow the pool without bound. Surplus buffers are dropped.
                while pending and pending[-1][0].done:
                    reclaimed = pending.pop()[2]
                    if len(reuse_buffers) < max_reuse_bufs:
                        reuse_buffers.append(reclaimed)

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(
                    buffers, copy=False, track=True
                )
                if not tracker.done:
                    # When encoding produced extra frames, they may reference
                    # memory owned by ``outputs``, so keep it alive until zmq
                    # is done with the send.
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: Optional[str] = None,
    ):
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        self.current_wave = 0
        self.last_counts = (0, 0)

        # Initialize the engine. The DP rank is passed to the base class as
        # its final positional argument.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Configure GPUs and the stateless process group for data parallel.

        Requires ``data_parallel_size > 1`` and
        ``0 <= local_dp_rank <= dp_rank < dp_size``. If a KV-transfer config
        is present, its ``engine_id`` gets a ``_dp{local_dp_rank}`` suffix.
        """
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug(
                "Setting kv_transfer_config.engine_id to %s",
                vllm_config.kv_transfer_config.engine_id,
            )

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine core, then tear down the DP process group.

        ``self.dp_group`` is cleared before the group is destroyed. Repeated
        calls therefore never destroy the same group twice.
        """
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            self.dp_group = None
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request, tracking DP waves when a coordinator is present.

        A request from a newer wave advances ``current_wave``. A request for
        an older wave that arrives while engines are idle asks the front-end
        to start a new wave.
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle START_DP_WAVE here and delegate everything else.

        A START_DP_WAVE payload is ``(new_wave, exclude_eng_index)``. It is
        ignored when addressed as excluding this engine or when it refers to
        an older wave.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Send scheduler request counts to the coordinator when they change.

        Does nothing unless ``publish_dp_lb_stats`` is set.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank may still have unfinished requests.

        The cross-rank check only runs every 32 steps. On the other steps
        this returns ``True`` without communicating.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild this engine's DP process group for a new DP layout.

        Steps, in order:

        1. Destroy the current DP group and shut the engine down.
        2. Apply the requested DP size, rank and master address/port.
        3. Re-create the DP group, unless this rank is being shut down.
        4. Forward the request to the model executor. If the DP size grew,
           sync the KV-cache memory budget and warm up the model.

        The request's ``new_data_parallel_master_port`` is overwritten with
        the port held by the parallel config after step 3.
        """
        # Tear down the current DP group exactly once. Clearing the attribute
        # stops shutdown() (here, and again for SHUTDOWN_CURRENT_RANK below)
        # from destroying the same group a second time.
        stateless_destroy_torch_distributed_process_group(self.dp_group)
        self.dp_group = None
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )


class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        """Record addresses and DP ranks, pin devices, then build the core.

        Args:
            vllm_config: Engine configuration. Its parallel config receives
                ``dp_rank`` and ``local_dp_rank``.
            local_client: Forwarded to ``DPEngineCoreProc``.
            addresses: ZMQ addresses, known before actor creation. They are
                returned by ``_perform_handshakes`` instead of being negotiated.
            executor_class: Forwarded to ``DPEngineCoreProc``.
            log_stats: Forwarded to ``DPEngineCoreProc``.
            dp_rank: Global data-parallel rank of this actor.
            local_dp_rank: Data-parallel rank local to this node.
        """
        self.addresses = addresses
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_rank = dp_rank
        parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

        # No handshake address: _perform_handshakes just yields self.addresses.
        super().__init__(vllm_config, local_client, "", executor_class, log_stats)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Limit this actor's visible devices to its DP rank's share.

        Nothing is done on XPU. On other platforms the platform's
        device-control environment variable is rewritten by
        ``_set_cuda_visible_devices``.
        """
        from vllm.platforms import current_platform

        if current_platform.is_xpu():
            return
        self._set_cuda_visible_devices(
            vllm_config, local_dp_rank, current_platform.device_control_env_var
        )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Point ``device_control_env_var`` at this DP rank's devices.

        The value comes from ``get_device_indices``. Judging from the error
        message, it presumably selects the half-open range
        ``[local_dp_rank * world_size, (local_dp_rank + 1) * world_size)`` of
        the variable's current value. TODO confirm against that helper.

        Raises:
            RuntimeError: If ``get_device_indices`` raises ``IndexError``.
                ``RuntimeError`` subclasses ``Exception``, so existing
                ``except Exception`` handlers still catch it.
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e
        # Kept outside the try so only the index lookup is guarded.
        os.environ[device_control_env_var] = value

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: Optional[str],
    ):
        """Yield the pre-assigned ZMQ addresses; no handshake is performed.

        Under Ray, all addresses are known before the actor is created and
        are handed to ``__init__``. The arguments exist only to match the
        overridden signature and are ignored.
        """
        addresses = self.addresses
        yield addresses

    def wait_for_init(self):
        """Initialization barrier that does no work itself.

        Ray runs actor methods only after ``__init__`` has completed. A
        finished ``ray.get()`` on this method (or any other actor method)
        therefore shows that the engine core was constructed.
        """

    def run(self):
        """Run the engine core busy loop until it exits.

        A ``SystemExit`` is logged at debug level and any other ``Exception``
        is logged as a fatal error. Both are re-raised, and ``shutdown()``
        always runs.
        """
        try:
            self.run_busy_loop()
        except (SystemExit, Exception) as exc:
            if isinstance(exc, SystemExit):
                logger.debug("EngineCore exiting.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()