# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import queue
import signal
import threading
import time
from collections import deque
from collections.abc import Generator
from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union

import msgspec
import zmq

from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket,
                        resolve_obj_by_qualname, set_process_title)
from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config,
                                         get_request_block_hasher,
                                         init_none_hash,
                                         unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType,
                            ReconfigureDistributedRequest, ReconfigureRankType,
                            UtilityOutput, UtilityResult)
from vllm.v1.engine.mm_input_cache import MultiModalInputCacheServer
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Polling timeout in seconds; presumably used by the IO / busy-loop code
# further down this module — confirm at the call sites.
POLLING_TIMEOUT_S = 2.5
# How long `EngineCoreProc.startup_handshake` waits for the front-end's
# init message before raising.
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar('_R')  # Return type for collective_rpc


class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler, the structured-output manager
    and the multi-modal input cache server, and exposes `step()` /
    `step_with_batch_queue()` to schedule, execute and collect outputs.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 executor_fail_callback: Optional[Callable] = None):
        """Build the executor, KV caches and scheduler.

        Args:
            vllm_config: Engine configuration. It is updated in place:
                `cache_config.num_gpu_blocks` / `num_cpu_blocks` are set
                from profiling, and `scheduler_config.chunked_prefill_enabled`
                is turned off when the model has no KV cache groups.
            executor_class: Executor type, instantiated with `vllm_config`.
            log_stats: Forwarded to the scheduler.
            executor_fail_callback: Optional callback registered on the
                executor through `register_failure_callback`.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        self.vllm_config = vllm_config
        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)

        # Overwritten by _initialize_kv_caches when the model has KV cache.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
            self._initialize_kv_caches(vllm_config)

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache",
                            args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. `scheduler_cls` may be a class or a qualified name.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        if len(kv_cache_config.kv_cache_groups) == 0:
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.chunked_prefill_enabled = False

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=(
                vllm_config.parallel_config.data_parallel_size > 1),
            log_stats=self.log_stats,
        )

        self.mm_input_cache_server = MultiModalInputCacheServer(
            vllm_config.model_config, MULTIMODAL_REGISTRY)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

        # Block hashes are only needed for prefix caching or KV transfer.
        self.request_block_hasher: Optional[Callable[[Request],
                                                     list[BlockHash]]] = None
        if (self.vllm_config.cache_config.enable_prefix_caching
                or self.scheduler.get_kv_connector() is not None):

            block_size = vllm_config.cache_config.block_size
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo)
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                block_size, caching_hash_fn)

    def _initialize_kv_caches(
            self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
        """Determine KV cache memory, build per-worker configs, initialize.

        Also warms up execution through
        `model_executor.initialize_from_config`, and records
        `self.available_gpu_memory_for_kv_cache` when the model has KV cache.

        Returns:
            `(num_gpu_blocks, num_cpu_blocks, kv_cache_config)`, where
            `num_cpu_blocks` is always 0 here and `kv_cache_config` is the
            first worker's unified config, used to build the scheduler.
        """
        # perf_counter is monotonic, unlike time.time(), so it is the right
        # clock for measuring an elapsed duration.
        start = time.perf_counter()

        # Get all kv cache needed by the model (one spec per worker).
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Elastic-EP scale-up: skip profiling and use the size from
                # ParallelConfig.sync_kv_cache_memory_size on `self.dp_group`
                # (presumably set by a data-parallel subclass — confirm).
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = \
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                available_gpu_memory = [
                    self.available_gpu_memory_for_kv_cache
                ] * len(kv_cache_specs)
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = (
                    self.model_executor.determine_available_memory())
                self.available_gpu_memory_for_kv_cache = \
                    available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache tensor size
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to initialize the scheduler.
        assert all(cfg.num_blocks == kv_cache_configs[0].num_blocks
                   for cfg in kv_cache_configs)
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0
        scheduler_kv_cache_config = kv_cache_configs[0]

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model executor reports as supported."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case (not used by this base implementation).

        Raises:
            TypeError: if `request.request_id` is not a `str`.
            ValueError: if the request carries pooling params whose task is
                not one of the executor's supported pooling tasks.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}")

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks()
                if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                                 f"Supported tasks: {supported_pooling_tasks}")

        # NOTE(review): this method only warns; it does not clear
        # `kv_transfer_params`. Presumably the scheduler ignores them when no
        # connector exists — confirm.
        if request.kv_transfer_params is not None and (
                not self.scheduler.get_kv_connector()):
            logger.warning("Got kv_transfer_params, but no KVConnector found. "
                           "Disabling KVTransfer for this request.")

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def execute_model_with_error_logging(
        self,
        model_fn: Callable[[SchedulerOutput], ModelRunnerOutput],
        scheduler_output: SchedulerOutput,
    ) -> ModelRunnerOutput:
        """Execute the model and log detailed info on failure.

        Any `Exception` raised by `model_fn` is re-raised unchanged after
        dumping the engine state.
        """
        try:
            return model_fn(scheduler_output)
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(self.vllm_config, scheduler_output,
                                  self.scheduler.make_stats())
            # Bare raise re-raises the active exception with its original
            # traceback.
            raise

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed. Returns `({}, False)` when the scheduler has no
        requests.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model_with_error_logging(
            self.model_executor.execute_model,  # type: ignore
            scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

        return (engine_core_outputs,
                scheduler_output.total_num_scheduled_tokens > 0)

    def step_with_batch_queue(
            self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        if not self.batch_queue.full():
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        # TODO(comaniac): Ideally we should peek the first batch in the
        # job queue to check if it's finished before scheduling a new batch,
        # but peeking the first element in a queue is not thread-safe,
        # so we need more work.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()

            # Blocking until the first result is available.
            model_output = self.execute_model_with_error_logging(
                lambda _: future.result(), scheduler_output)

            self.batch_queue.task_done()
            engine_core_outputs = (self.scheduler.update_from_output(
                scheduler_output, model_output))

        return engine_core_outputs, scheduled_batch

    def shutdown(self):
        """Clear the structured-output backend, then shut down the executor
        and scheduler if they are set."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Forward a profiler start/stop request (`is_start`) to the
        executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Reset the multi-modal input cache server (debugging aid)."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
        if self.scheduler.has_unfinished_requests():
            logger.warning("Resetting the multi-modal cache when requests are "
                           "in progress may lead to desynced internal caches.")

        self.mm_input_cache_server.reset()

    def reset_prefix_cache(self):
        """Ask the scheduler to reset its prefix cache."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Forward a sleep request with the given `level` to the executor."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Forward a wake-up request (optionally limited to `tags`)."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's `is_sleeping` state."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run `execute_dummy_batch` on all workers via collective RPC."""
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA load request to the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA removal request to the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Forward a LoRA pin request to the executor."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Forward a sharded-state save request to the executor."""
        self.model_executor.save_sharded_state(path=path,
                                               pattern=pattern,
                                               max_size=max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Invoke `method` on all workers through the executor."""
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)

    def save_tensorized_model(
        self,
        tensorizer_config,
    ) -> None:
        """Forward a tensorizer save request to the executor."""
        self.model_executor.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

    def preprocess_add_request(
            self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The converted `Request` and `request.current_wave`.
        """
        if request.mm_hashes is not None:
            assert request.mm_kwargs is not None

            # Note on thread safety: no race condition.
            # `mm_input_cache_server` is reset at the end of LLMEngine init,
            # and will only be accessed in the input processing thread
            # afterwards.
            request.mm_kwargs = self.mm_input_cache_server.get_and_update(
                request.mm_kwargs, request.mm_hashes)

        req = Request.from_engine_core_request(request,
                                               self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

450
451
    # Sentinel payload; presumably sent to clients when this engine dies
    # (see `_send_engine_dead`, defined outside this view) — confirm.
    ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: Optional[str] = None,
        engine_index: int = 0,
    ):
        """Handshake with the front-end(s), build the engine, start IO threads.

        The handshake context stays open while `EngineCore.__init__` runs and
        the input/output socket threads start. The READY reply is therefore
        only sent once the engine is fully initialized. If a DP coordinator
        input address was provided, this also waits for the input thread to
        signal `ready_event`.

        Raises:
            RuntimeError: if the input socket thread dies during startup.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
                                              bytes]]()

        def executor_fail_callback():
            # Surface executor failures to the busy loop via the input queue.
            self.input_queue.put_nowait(
                (EngineCoreRequestType.EXECUTOR_FAILED, b''))

        self.engine_index = engine_index
        # 2-byte little-endian engine index used as the ZMQ identity.
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(handshake_address, identity,
                                      local_client, vllm_config,
                                      client_handshake_address) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address)
            logger.debug("Has DP Coordinator: %s, stats publish address: %s",
                         self.has_coordinator,
                         self.frontend_stats_publish_address)
            # Only publish request queue stats to coordinator for "internal"
            # and "hybrid" LB modes.
            self.publish_dp_lb_stats = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb)

            self._init_data_parallel(vllm_config)

            super().__init__(vllm_config, executor_class, log_stats,
                             executor_fail_callback)

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_thread = threading.Thread(target=self.process_input_sockets,
                                            args=(addresses.inputs,
                                                  addresses.coordinator_input,
                                                  identity, ready_event),
                                            daemon=True)
            input_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(addresses.outputs, addresses.coordinator_output,
                      self.engine_index),
                daemon=True)
            self.output_thread.start()

            # Don't complete handshake until DP coordinator ready message is
            # received.
            while not ready_event.wait(timeout=10):
                if not input_thread.is_alive():
                    raise RuntimeError(
                        "Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

        # Pick the step implementation once, based on whether a batch queue
        # was created by EngineCore.__init__.
        self.step_fn = (self.step if self.batch_queue is None else
                        self.step_with_batch_queue)

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: Optional[str],
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.

        The ZMQ context created here is destroyed once the handshakes finish
        (or fail). After a successful exit, `vllm_config.__post_init__()` is
        re-run because the handshake may have updated its parallel config.
        """
        input_ctx = zmq.Context()
        try:
            is_local = local_client and client_handshake_address is None
            headless = not local_client
            handshake = self._perform_handshake(input_ctx, handshake_address,
                                                identity, is_local, headless,
                                                vllm_config,
                                                vllm_config.parallel_config)
            if client_handshake_address is None:
                with handshake as addresses:
                    yield addresses
            else:
                assert local_client
                local_handshake = self._perform_handshake(
                    input_ctx, client_handshake_address, identity, True, False,
                    vllm_config)
                with handshake as addresses:
                    with local_handshake as client_addresses:
                        addresses.inputs = client_addresses.inputs
                        addresses.outputs = client_addresses.outputs
                        yield addresses
        finally:
            # Previously the context was never closed. By now the handshake
            # sockets have been closed by their `with` blocks, so this only
            # terminates the context. Termination waits at most the sockets'
            # linger for pending messages such as READY.
            input_ctx.destroy()

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: Optional[ParallelConfig] = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run one HELLO / READY handshake over a DEALER socket.

        Connects to `handshake_address` and sends HELLO (through
        `startup_handshake`). It then yields the addresses received from the
        front-end. If the caller's `with` body completes without raising, it
        sends a READY message carrying the final `num_gpu_blocks` and the
        DP stats publish address. If the body raises, the exception propagates
        out of the `yield` and READY is not sent.

        `parallel_config_to_update`, if given, is updated in place from the
        front-end's init message.
        """
        with make_zmq_socket(ctx,
                             handshake_address,
                             zmq.DEALER,
                             identity=identity,
                             linger=5000,
                             bind=False) as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(handshake_socket, local_client,
                                               headless,
                                               parallel_config_to_update)
            yield addresses

            # Send ready message.
            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
            # We pass back the coordinator stats update address here for the
            # external LB case for our colocated front-end to use (coordinator
            # only runs with rank 0).
            dp_stats_address = self.frontend_stats_publish_address
            handshake_socket.send(
                msgspec.msgpack.encode({
                    "status": "READY",
                    "local": local_client,
                    "headless": headless,
                    "num_gpu_blocks": num_gpu_blocks,
                    "dp_stats_address": dp_stats_address,
                }))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: Optional[ParallelConfig] = None,
    ) -> EngineZmqAddresses:
        """Send HELLO and wait for the front-end's init message.

        Args:
            handshake_socket: Connected handshake socket.
            local_client: Sent to the front-end as the "local" flag.
            headless: Sent to the front-end as the "headless" flag.
            parallel_config: If given, every key/value in the init message's
                `parallel_config` mapping is applied to it with `setattr`.

        Returns:
            The `addresses` field of the decoded `EngineHandshakeMetadata`.

        Raises:
            RuntimeError: if no reply arrives within HANDSHAKE_TIMEOUT_MINS.
        """
        # Send registration message.
        handshake_socket.send(
            msgspec.msgpack.encode({
                "status": "HELLO",
                "local": local_client,
                "headless": headless,
            }))

        # Receive initialization message. zmq poll timeouts are in ms.
        logger.info("Waiting for init message from front-end.")
        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
            raise RuntimeError("Did not receive response from front-end "
                               f"process within {HANDSHAKE_TIMEOUT_MINS} "
                               f"minutes")
        init_bytes = handshake_socket.recv()
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            init_bytes, type=EngineHandshakeMetadata)
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for key, value in init_message.parallel_config.items():
                setattr(parallel_config, key, value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args,
                        dp_rank: int = 0,
                        local_dp_rank: int = 0,
                        **kwargs):
        """Launch EngineCore busy loop in background process.

        Builds a `DPEngineCoreProc` when data parallelism is in use
        (`data_parallel_size > 1` or `dp_rank > 0`), otherwise an
        `EngineCoreProc`, then runs its busy loop. `args` / `kwargs` are
        forwarded to the constructor, and `kwargs` must contain
        "vllm_config". On a fatal error after construction the engine's dead
        signal is sent. The engine is always shut down if it was constructed.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: Optional[EngineCoreProc] = None
        try:
            parallel_config: ParallelConfig = kwargs[
                "vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                set_process_title("DPEngineCore", str(dp_rank))
                decorate_logs()
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                set_process_title("EngineCore")
                decorate_logs()
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare raise keeps the original traceback intact.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        pass

712
713
714
    def run_busy_loop(self):
        """Drive the EngineCore forever.

        Every iteration first services client input (blocking while the
        engine is idle) and then performs one engine step. The loop is only
        left via an exception, e.g. the SystemExit raised on SIGINT/SIGTERM.
        """
        while True:
            # Wait for work if idle, then dispatch all queued client requests.
            self._process_input_queue()
            # Advance the engine by one step and enqueue its outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed."""

        waited = False
726
        while not self.engines_running and not self.scheduler.has_requests():
727
728
729
730
731
732
733
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
734
            logger.debug("EngineCore loop active.")
735
736
737
738
739
740

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

741
    def _process_engine_step(self) -> bool:
742
743
744
        """Called only when there are unfinished local requests."""

        # Step the engine core.
745
        outputs, model_executed = self.step_fn()
746
        # Put EngineCoreOutputs into the output queue.
747
748
        for output in (outputs.items() if outputs else ()):
            self.output_queue.put_nowait(output)
749

750
751
        return model_executed

752
753
754
    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        Args:
            request_type: Kind of request received.
            request: Decoded payload whose shape depends on ``request_type``:
                ADD -> ``(request, request_wave)``; ABORT -> passed to
                ``abort_requests``; UTILITY ->
                ``(client_idx, call_id, method_name, args)``.

        Raises:
            RuntimeError: On an EXECUTOR_FAILED request.
            SystemExit, KeyboardInterrupt: Propagated from a utility call
                rather than being reported as a utility failure.
        """
        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except (SystemExit, KeyboardInterrupt):
                # Never swallow termination: the signal handler installed by
                # run_engine_core raises SystemExit only once, so converting
                # it into a utility failure would keep the engine running.
                raise
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output)))
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error("Unrecognized input request type encountered: %s",
                         request_type)
779
780
781
782
783
784
785
786
787
788
789
790
791
792

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
         arg type, try converting to msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation) else v
            for v, p in zip(args, arg_types))
793

794
795
796
797
798
799
800
801
802
803
804
805
    def _send_engine_dead(self):
        """Notify the EngineCoreClient that this engine core has died.

        Enqueues the ENGINE_CORE_DEAD sentinel for the output thread, then
        waits up to 5 seconds for that thread to finish so the message can
        be flushed before shutdown continues.
        """
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # The output thread presumably runs process_output_sockets, which
        # leaves its loop after forwarding the sentinel; a finished thread
        # therefore signals the message went out.
        self.output_thread.join(timeout=5.0)
        if not self.output_thread.is_alive():
            return
        logger.fatal("vLLM shutdown signal from EngineCore failed "
                     "to send. Please report this issue.")

806
807
    def process_input_sockets(self, input_addresses: list[str],
                              coord_input_address: Optional[str],
                              identity: bytes, ready_event: threading.Event):
        """Input socket IO thread.

        Connects a DEALER socket to each front-end input address and, when
        ``coord_input_address`` is given, an XSUB socket to the DP
        coordinator. Sets ``ready_event`` once setup is complete, then
        forever decodes incoming requests and pushes
        ``(request_type, request)`` tuples onto ``self.input_queue``.

        Args:
            input_addresses: Front-end input socket addresses.
            coord_input_address: Coordinator address, or None when there is
                no coordinator.
            identity: ZMQ identity used for every socket.
            ready_event: Set once sockets are connected (and, with a
                coordinator, after its READY message arrives).

        Raises:
            RuntimeError: If the coordinator's first message is not READY.
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx,
                                    input_address,
                                    zmq.DEALER,
                                    identity=identity,
                                    bind=False))
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(ctx,
                                    coord_input_address,
                                    zmq.XSUB,
                                    identity=identity,
                                    bind=False))
                # Send subscription message to coordinator.
                coord_socket.send(b'\x01')

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b'')
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. This must not be an
                # assert: under `python -O` the whole statement (including the
                # recv) would be stripped and READY would later be misread as
                # a request in the poll loop below.
                ready_msg = coord_socket.recv()
                if ready_msg != b"READY":
                    raise RuntimeError(
                        "Expected READY message from DP coordinator, got "
                        f"{ready_msg!r}")
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(
                        copy=False)
                    request_type = EngineCoreRequestType(
                        bytes(type_frame.buffer))

                    # Deserialize the request data.
                    if request_type == EngineCoreRequestType.ADD:
                        request = add_request_decoder.decode(data_frames)
                        request = self.preprocess_add_request(request)
                    else:
                        request = generic_decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(self, output_paths: list[str],
                               coord_output_path: Optional[str],
                               engine_index: int):
        """Output socket IO thread.

        Takes ``(client_index, outputs)`` items from ``self.output_queue``,
        stamps ``outputs.engine_index`` and sends them on the matching
        front-end PUSH socket. ``client_index == -1`` goes to the
        coordinator socket instead. Encode buffers are recycled once zmq
        reports it has finished with them. On the ENGINE_CORE_DEAD sentinel,
        the sentinel is forwarded to every front-end socket and the thread
        returns.
        """
        encoder = MsgpackEncoder()
        # Buffers ready to be reused by the next encode.
        free_buffers: list[bytearray] = []
        # Zero-copy sends still in flight, newest on the left:
        # (tracker, outputs kept alive or None, buffer). The outputs object
        # is retained because it may own tensors/np arrays whose memory was
        # handed to zmq without copying.
        in_flight = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        with ExitStack() as stack, zmq.Context() as ctx:
            # A non-zero linger ensures the ENGINE_CORE_DEAD message is
            # delivered before the sockets are closed.
            client_sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, path, zmq.PUSH, linger=4000))
                for path in output_paths
            ]
            if coord_output_path is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(ctx,
                                    coord_output_path,
                                    zmq.PUSH,
                                    bind=False,
                                    linger=4000))
            buffer_pool_limit = len(client_sockets) + 1

            while True:
                item = self.output_queue.get()
                if item == EngineCoreProc.ENGINE_CORE_DEAD:
                    for sock in client_sockets:
                        sock.send(item)
                    break
                assert not isinstance(item, bytes)
                client_index, outputs = item
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Coordinator messages are tiny; skip buffer reuse.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # The oldest sends sit at the right end; reclaim those zmq
                # has finished with.
                while in_flight and in_flight[-1][0].done:
                    free_buffers.append(in_flight.pop()[2])

                buffer = free_buffers.pop() if free_buffers else bytearray()
                frames = encoder.encode_into(outputs, buffer)
                tracker = client_sockets[client_index].send_multipart(
                    frames, copy=False, track=True)
                if not tracker.done:
                    keep_alive = outputs if len(frames) > 1 else None
                    in_flight.appendleft((tracker, keep_alive, buffer))
                elif len(free_buffers) < buffer_pool_limit:
                    # Cap how many idle buffers are kept around.
                    free_buffers.append(buffer)
931
932
933
934
935
936
937
938
939


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: Optional[str] = None,
    ):
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        # Index of the DP wave this engine is currently part of.
        self.current_wave = 0
        # Request counts last published; used to skip unchanged updates.
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(vllm_config, local_client, handshake_address,
                         executor_class, log_stats, client_handshake_address,
                         dp_rank)

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Validate DP ranks and create the stateless DP process group.

        When a KV transfer config is present, its ``engine_id`` is suffixed
        with ``_dp{local_dp_rank}`` so it is unique per DP rank.
        """
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug("Setting kv_transfer_config.engine_id to %s",
                         vllm_config.kv_transfer_config.engine_id)

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine, then destroy the DP group if one exists."""
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request after reconciling its DP wave with ours.

        With a coordinator, a request from a newer wave moves this engine to
        that wave. A request from an older wave that arrives while engines
        are idle triggers a ``start_wave`` message for the current wave,
        sent with client index -1 (the coordinator socket).
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave)))

        super().add_request(request, request_wave)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Handle START_DP_WAVE locally; defer all other types to the base.

        A START_DP_WAVE payload is ``(new_wave, exclude_eng_index)``. Unless
        this engine is the excluded one or the wave is older than ours, the
        new wave is adopted and an idle engine starts running.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                    new_wave >= self.current_wave):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.",
                                 new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Publish scheduler request counts when they have changed.

        No-op unless ``publish_dp_lb_stats`` is set. Changed counts are sent
        as ``SchedulerStats`` with client index -1 (the coordinator socket).
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(*counts,
                                   step_counter=self.step_counter,
                                   current_wave=self.current_wave)
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs)

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug("Wave %d finished, pausing engine loop.",
                                 self.current_wave)
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (client_index,
                         EngineCoreOutputs(wave_complete=self.current_wave)))
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank still has unfinished requests.

        The cross-rank check only runs every 32 steps; on other steps this
        returns True so the wave keeps running.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)

    def reinitialize_distributed(
            self, reconfig_request: ReconfigureDistributedRequest) -> None:
        """Rebuild the DP environment after a DP-size change.

        Shuts the engine down, applies the new DP size/rank/master address
        from ``reconfig_request``, recreates the DP group (unless this rank
        is being shut down) and reinitializes the model executor. When
        scaling up, the available KV-cache memory is synced to the group and
        the model is warmed up. A rank marked SHUTDOWN_CURRENT_RANK is shut
        down again at the end.
        """
        # NOTE(review): shutdown() below also destroys self.dp_group, so the
        # group is torn down twice here; confirm that is harmless.
        stateless_destroy_torch_distributed_process_group(self.dp_group)
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = \
            reconfig_request.new_data_parallel_size
        if reconfig_request.new_data_parallel_rank != \
                ReconfigureRankType.KEEP_CURRENT_RANK:
            parallel_config.data_parallel_rank = \
                reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert reconfig_request.new_data_parallel_rank_local == \
            ReconfigureRankType.KEEP_CURRENT_RANK
        parallel_config.data_parallel_master_ip = \
            reconfig_request.new_data_parallel_master_ip
        parallel_config.data_parallel_master_port = \
            reconfig_request.new_data_parallel_master_port
        if reconfig_request.new_data_parallel_rank != \
                ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        # Write the (possibly updated) master port back onto the request
        # before it is forwarded to the executor.
        reconfig_request.new_data_parallel_master_port = \
            parallel_config.data_parallel_master_port

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache)
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if reconfig_request.new_data_parallel_rank == \
                ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info("Distributed environment reinitialized for DP rank %s",
                        self.dp_rank)

Rui Qiao's avatar
Rui Qiao committed
1121
1122
1123
1124
1125
1126
1127
1128
1129

class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Ray supplies every address up front; _perform_handshakes yields
        # these instead of running a real handshake.
        self.addresses = addresses
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_rank = dp_rank
        parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_cuda_visible_devices(vllm_config, local_dp_rank)

        # The handshake address is unused since handshakes are skipped.
        super().__init__(vllm_config, local_client, "", executor_class,
                         log_stats)

1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
    def _set_cuda_visible_devices(self, vllm_config: VllmConfig,
                                  local_dp_rank: int):
        """Restrict this actor to the devices of its local DP rank.

        Local DP rank ``r`` gets logical devices
        ``[r * world_size, (r + 1) * world_size)``. These are mapped to
        physical ids and written to the platform's device-control
        environment variable (CUDA_VISIBLE_DEVICES or equivalent).

        Raises:
            RuntimeError: If the device range cannot be mapped (the
                platform raised IndexError). This is still an ``Exception``,
                so existing ``except Exception`` handlers keep working.
        """
        from vllm.platforms import current_platform
        device_control_env_var = current_platform.device_control_env_var
        world_size = vllm_config.parallel_config.world_size
        start = local_dp_rank * world_size
        end = start + world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            os.environ[device_control_env_var] = ",".join(
                str(current_platform.device_id_to_physical_device_id(i))
                for i in range(start, end))
        except IndexError as e:
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{start}, {end}) "
                f"base value: \"{os.getenv(device_control_env_var)}\"") from e

Rui Qiao's avatar
Rui Qiao committed
1182
    @contextmanager
    def _perform_handshakes(self, handshake_address: str, identity: bytes,
                            local_client: bool, vllm_config: VllmConfig,
                            client_handshake_address: Optional[str]):
        """
        Ray replacement for the startup handshake: nothing is exchanged.

        All ZMQ addresses were passed in when the actor was created, so the
        stored ``self.addresses`` is yielded as-is and every argument is
        ignored.
        """
        known_addresses = self.addresses
        yield known_addresses

    def wait_for_init(self):
        """
        Block until the engine core is initialized; intentionally a no-op.

        Ray only runs an actor method after the actor's __init__ has
        finished, so a ray.get() on this (or any other) method returning
        already guarantees initialization is complete.
        """
        return None

    def run(self):
        """
        Run the engine core busy loop until it exits.

        SystemExit is logged at debug level and any other exception is
        logged with its traceback; both are re-raised. The engine is shut
        down on every exit path.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            # Normal termination path (e.g. signal-triggered exit).
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()