core.py 52.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import gc
4
import os
5
import queue
6
import signal
7
8
import threading
import time
9
from collections import deque
Rui Qiao's avatar
Rui Qiao committed
10
from collections.abc import Generator
11
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
12
from contextlib import ExitStack, contextmanager
13
from inspect import isclass, signature
14
from logging import DEBUG
15
from typing import Any, Callable, Optional, TypeVar, Union
16

17
import msgspec
18
19
import zmq

20
21
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
22
from vllm.logger import init_logger
23
from vllm.logging_utils.dump_input import dump_engine_exception
24
from vllm.lora.request import LoRARequest
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.multimodal.cache import engine_receiver_cache_from_config
27
from vllm.tasks import POOLING_TASKS, SupportedTask
28
29
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
30
from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket,
31
                        resolve_obj_by_qualname, set_process_title)
32
from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_configs,
33
                                         get_request_block_hasher,
34
                                         init_none_hash)
35
from vllm.v1.core.sched.interface import SchedulerInterface
36
37
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
38
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
39
40
                            EngineCoreRequestType,
                            ReconfigureDistributedRequest, ReconfigureRankType,
41
                            UtilityOutput, UtilityResult)
42
43
from vllm.v1.engine.utils import (EngineHandshakeMetadata, EngineZmqAddresses,
                                  get_device_indices)
44
from vllm.v1.executor.abstract import Executor
45
from vllm.v1.kv_cache_interface import KVCacheConfig
46
from vllm.v1.metrics.stats import SchedulerStats
47
from vllm.v1.outputs import ModelRunnerOutput
48
from vllm.v1.request import Request, RequestStatus
49
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
50
from vllm.v1.structured_output import StructuredOutputManager
51
52
53
54
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Polling timeout in seconds (per the `_S` suffix); its consumers are outside
# the visible part of this module.
POLLING_TIMEOUT_S = 2.5

# Upper bound, in minutes, that startup_handshake() waits for the front-end's
# init message before giving up.
HANDSHAKE_TIMEOUT_MINS = 5

# Return type for collective_rpc
_R = TypeVar("_R")

60
61
62
63

class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler and the structured-output
    manager. ``step_fn`` is bound in ``__init__`` to either :meth:`step` or
    :meth:`step_with_batch_queue` depending on whether a batch queue is used.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 executor_fail_callback: Optional[Callable] = None):
        """Create the executor, size the KV cache and build the scheduler.

        Args:
            vllm_config: Engine configuration. This method writes
                ``cache_config.num_gpu_blocks`` / ``num_cpu_blocks`` back into
                it and may turn off ``scheduler_config.chunked_prefill_enabled``
                for models without a KV cache.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Forwarded to the scheduler.
            executor_fail_callback: If given, registered on the executor via
                ``register_failure_callback``.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        self.vllm_config = vllm_config
        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)

        # Stays -1 unless _initialize_kv_caches() finds a model KV cache.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
            self._initialize_kv_caches(vllm_config)

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache",
                            args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. `scheduler_cls` may be a qualified name or a class.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        if len(kv_cache_config.kv_cache_groups) == 0:
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.chunked_prefill_enabled = False

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            # Only requested when data parallelism is enabled (DP size > 1).
            include_finished_set=(
                vllm_config.parallel_config.data_parallel_size > 1),
            log_stats=self.log_stats,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = engine_receiver_cache_from_config(
            vllm_config, mm_registry)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[deque[tuple[Future[ModelRunnerOutput],
                                               SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        # Block hashing is only needed for prefix caching or a KV connector.
        self.request_block_hasher: Optional[Callable[[Request],
                                                     list[BlockHash]]] = None
        if (self.vllm_config.cache_config.enable_prefix_caching
                or self.scheduler.get_kv_connector() is not None):

            block_size = vllm_config.cache_config.block_size
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo)
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                block_size, caching_hash_fn)

        self.step_fn = (self.step if self.batch_queue is None else
                        self.step_with_batch_queue)

    def _initialize_kv_caches(
            self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
        """Determine KV cache memory, build per-worker configs and warm up.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, kv_cache_config)``. The config
            is the first entry of the per-spec configs and is the one handed
            to the scheduler; ``num_cpu_blocks`` is always 0 here.
        """
        # perf_counter is monotonic, so the logged duration is not skewed by
        # wall-clock adjustments the way time.time() can be.
        start = time.perf_counter()

        # Get all kv cache needed by the model (presumably one spec per
        # worker, see the note on kv_cache_configs below).
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Elastic-EP scale-up launch: skip profiling and reuse a
                # memory size obtained through the DP group. `dp_group` is not
                # set by this class; presumably a subclass provides it.
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = \
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                available_gpu_memory = [
                    self.available_gpu_memory_for_kv_cache
                ] * len(kv_cache_specs)
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = (
                    self.model_executor.determine_available_memory())
                self.available_gpu_memory_for_kv_cache = \
                    available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
                                                available_gpu_memory)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to initialize the scheduler.
        assert all(cfg.num_blocks == kv_cache_configs[0].num_blocks
                   for cfg in kv_cache_configs)
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0
        scheduler_kv_cache_config = kv_cache_configs[0]

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the executor reports as supported."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case. It is not used by this base implementation.

        Raises:
            TypeError: If ``request.request_id`` is not a ``str``.
            ValueError: If the request carries pooling params whose task is
                not one of the model's supported pooling tasks.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}")

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks()
                if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(f"Unsupported task: {pooling_params.task!r}. "
                                 f"Supported tasks: {supported_pooling_tasks}")

        if request.kv_transfer_params is not None and (
                not self.scheduler.get_kv_connector()):
            # NOTE(review): the request is passed on unchanged; this relies on
            # downstream code ignoring kv_transfer_params when no connector is
            # configured - confirm.
            logger.warning("Got kv_transfer_params, but no KVConnector found. "
                           "Disabling KVTransfer for this request.")

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def execute_model_with_error_logging(
        self,
        model_fn: Callable[[SchedulerOutput], ModelRunnerOutput],
        scheduler_output: SchedulerOutput,
    ) -> ModelRunnerOutput:
        """Execute the model and log detailed info on failure.

        Returns ``model_fn(scheduler_output)``. On any ``Exception`` the engine
        state is dumped via ``dump_engine_exception`` and the original
        exception is re-raised unchanged.
        """
        try:
            return model_fn(scheduler_output)
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: dump_engine_exception is exception-free
            dump_engine_exception(self.vllm_config, scheduler_output,
                                  self.scheduler.make_stats())
            # Bare raise re-raises the active exception without adding an
            # extra traceback entry for this line.
            raise

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model_with_error_logging(
            self.model_executor.execute_model,  # type: ignore
            scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

        return (engine_core_outputs,
                scheduler_output.total_num_scheduled_tokens > 0)

    def post_step(self, model_executed: bool) -> None:
        """Forward speculative draft token ids to the scheduler.

        Only done when speculative decoding is configured and the preceding
        step actually executed the model.
        """
        if self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
            self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
        """Schedule and execute batches with the batch queue.

        Used instead of :meth:`step` when the executor supports more than one
        concurrent batch (e.g. pipeline parallelism). Outputs are ``None``
        when no batch result is consumed in this call.

        The execution flow is as follows:

        1. If the scheduler has requests, schedule a new batch and submit it
           to the executor without blocking. If that batch has scheduled
           tokens, the queue still has room and the oldest in-flight batch
           has not finished, return ``(None, True)`` immediately: keeping the
           queue full takes priority over collecting outputs.
        2. Otherwise (queue full, nothing scheduled, or the oldest batch is
           already done) block until the oldest batch in the queue finishes.
           If nothing was scheduled and the queue is empty, return
           ``(None, False)`` instead.
        3. Update the scheduler from that batch's output.

        Returns:
            ``(outputs or None, model_executed)`` where ``model_executed`` is
            True iff this call submitted a batch with scheduled tokens.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        # The deque has maxlen=batch_queue_size, so appending to a full queue
        # would silently drop an in-flight batch; the assert guards that.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            future = self.model_executor.execute_model(scheduler_output,
                                                       non_block=True)
            batch_queue.appendleft(
                (future, scheduler_output))  # type: ignore[arg-type]

            model_executed = scheduler_output.total_num_scheduled_tokens > 0
            if model_executed and len(batch_queue) < self.batch_queue_size \
                and not batch_queue[-1][0].done():
                # Don't block on next worker response unless the queue is full
                # or there are no more requests to schedule.
                return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available (oldest batch first).
        future, scheduler_output = batch_queue.pop()
        model_output = self.execute_model_with_error_logging(
            lambda _: future.result(), scheduler_output)

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)

        return engine_core_outputs, model_executed

    def shutdown(self):
        """Release the structured-output backend, executor and scheduler.

        The cleanups run in that order, and each one runs even if an earlier
        one raised. Any exception is re-raised after all of them have run (if
        several fail, the last one propagates with the earlier ones chained
        as its context). Components that were never created are skipped.
        """
        with ExitStack() as cleanup:
            # ExitStack runs callbacks last-in-first-out, so they are
            # registered in reverse of the desired execution order.
            scheduler = getattr(self, "scheduler", None)
            if scheduler:
                cleanup.callback(scheduler.shutdown)
            model_executor = getattr(self, "model_executor", None)
            if model_executor:
                cleanup.callback(model_executor.shutdown)
            structured_output_manager = getattr(self,
                                                "structured_output_manager",
                                                None)
            if structured_output_manager is not None:
                cleanup.callback(structured_output_manager.clear_backend)

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling on the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Clear the multi-modal receiver cache, if one is configured."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
        if self.scheduler.has_unfinished_requests():
            logger.warning("Resetting the multi-modal cache when requests are "
                           "in progress may lead to desynced internal caches.")

        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

    def reset_prefix_cache(self):
        """Ask the scheduler to reset its prefix cache."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the executor to sleep at the given level."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Wake the executor up, optionally restricted to ``tags``."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` state."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run a dummy batch on the executor."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter on the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with ``lora_id`` from the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the ids of LoRA adapters known to the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter with ``lora_id`` on the executor."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate sharded-state saving to the executor."""
        self.model_executor.save_sharded_state(path=path,
                                               pattern=pattern,
                                               max_size=max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Run ``method`` via the executor's ``collective_rpc``."""
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)

    def save_tensorized_model(
        self,
        tensorizer_config,
    ) -> None:
        """Delegate tensorized model saving to the executor."""
        self.model_executor.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

    def preprocess_add_request(
            self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The constructed ``Request`` and ``request.current_wave``.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = (
                self.mm_receiver_cache.get_and_update_features(
                    request.mm_features))

        req = Request.from_engine_core_request(request,
                                               self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave

453
454
455
456

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

457
458
    ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'

459
460
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: Optional[str] = None,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, build the engine and start IO threads.

        Args:
            vllm_config: Engine configuration.
            local_client: Passed to the handshake logic; the engine reports
                itself as ``headless`` when this is False.
            handshake_address: Address of the front-end handshake socket.
            executor_class: Executor type forwarded to ``EngineCore``.
            log_stats: Forwarded to ``EngineCore``.
            client_handshake_address: If set, a second handshake with the
                colocated front-end supplies the client socket addresses.
            engine_index: Encoded as a 2-byte little-endian ZMQ identity.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
                                              bytes]]()

        def executor_fail_callback():
            # Enqueue an EXECUTOR_FAILED marker on the input queue.
            self.input_queue.put_nowait(
                (EngineCoreRequestType.EXECUTOR_FAILED, b''))

        self.engine_index = engine_index
        identity = engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(handshake_address, identity,
                                      local_client, vllm_config,
                                      client_handshake_address) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address)
            logger.debug("Has DP Coordinator: %s, stats publish address: %s",
                         self.has_coordinator,
                         self.frontend_stats_publish_address)

            # Request-queue stats go to the coordinator only in the "internal"
            # and "hybrid" load-balancing modes (i.e. not external LB).
            self.publish_dp_lb_stats = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb)

            self._init_data_parallel(vllm_config)

            super().__init__(vllm_config, executor_class, log_stats,
                             executor_fail_callback)

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(addresses.inputs, addresses.coordinator_input, identity,
                      ready_event),
                daemon=True)
            input_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(addresses.outputs, addresses.coordinator_output,
                      self.engine_index),
                daemon=True)
            self.output_thread.start()

            # Hold the handshake open until the input thread signals ready
            # (which, with a DP coordinator, means its READY message arrived).
            while True:
                if ready_event.wait(timeout=10):
                    break
                if not input_thread.is_alive():
                    raise RuntimeError(
                        "Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

        # Treat everything allocated during startup as permanent so that the
        # oldest-generation GC passes skip it (shorter pause times).
        gc.collect()
        gc.freeze()

Rui Qiao's avatar
Rui Qiao committed
536
    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: Optional[str],
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Run the startup handshake(s) and yield the resulting ZMQ addresses.

        With DP=1 or in offline mode, the only peer is the colocated
        front-end process.

        With DP>1 and internal load-balancing, the peer is the shared
        front-end process, which may live on another node.

        With DP>1 and external or hybrid load-balancing, there are two
        handshakes:
            - one with the rank 0 front-end process, which provides the
              DP Coordinator ZMQ addresses and the DP process group address;
            - one with the colocated front-end process, which provides the
              client input/output socket addresses.
        The rank 0 and colocated engines themselves skip the second one.

        "Front-end" here means either the process hosting the engine core
        client (the API server, when it is not scaled out) or the launcher
        process running run_multi_api_server() in serve.py.

        After the body completes, ``vllm_config.__post_init__()`` is re-run
        because the handshake may have changed the config.
        """
        # NOTE(review): this context is never explicitly terminated here;
        # cleanup is left to garbage collection - confirm this is intended.
        ctx = zmq.Context()
        primary_handshake = self._perform_handshake(
            ctx,
            handshake_address,
            identity,
            local_client and client_handshake_address is None,
            not local_client,
            vllm_config,
            vllm_config.parallel_config,
        )

        if client_handshake_address is not None:
            assert local_client
            # Second handshake, with the colocated front-end, only to obtain
            # the client-facing input/output socket addresses.
            client_handshake = self._perform_handshake(
                ctx, client_handshake_address, identity, True, False,
                vllm_config)
            with primary_handshake as addresses, \
                    client_handshake as client_addresses:
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
                yield addresses
        else:
            with primary_handshake as addresses:
                yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: Optional[ParallelConfig] = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Register with one front-end and send READY once the body finishes.

        Opens a DEALER socket to ``handshake_address``, performs the HELLO /
        init exchange via :meth:`startup_handshake` and yields the addresses
        it returns. When the ``with`` body exits normally, a READY message is
        sent with the GPU block count and the stats publish address.
        """
        with make_zmq_socket(ctx,
                             handshake_address,
                             zmq.DEALER,
                             identity=identity,
                             linger=5000,
                             bind=False) as handshake_socket:
            # Register engine with front-end.
            yield self.startup_handshake(handshake_socket, local_client,
                                         headless, parallel_config_to_update)

            ready_message = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
                # Passed back so that, in the external LB case, the colocated
                # front-end can use the coordinator stats update address
                # (the coordinator only runs with rank 0).
                "dp_stats_address": self.frontend_stats_publish_address,
            }
            handshake_socket.send(msgspec.msgpack.encode(ready_message))

628
    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: Optional[ParallelConfig] = None,
    ) -> EngineZmqAddresses:
        """Send HELLO to the front-end and wait for its init message.

        If ``parallel_config`` is given, every key/value in the init message's
        ``parallel_config`` mapping is written onto it with ``setattr``.

        Returns:
            The ``addresses`` carried by the init message.

        Raises:
            RuntimeError: If no reply arrives within HANDSHAKE_TIMEOUT_MINS.
        """
        hello_message = msgspec.msgpack.encode({
            "status": "HELLO",
            "local": local_client,
            "headless": headless,
        })
        handshake_socket.send(hello_message)

        logger.info("Waiting for init message from front-end.")
        # zmq poll timeouts are expressed in milliseconds.
        poll_timeout_ms = HANDSHAKE_TIMEOUT_MINS * 60_000
        if not handshake_socket.poll(timeout=poll_timeout_ms):
            raise RuntimeError("Did not receive response from front-end "
                               f"process within {HANDSHAKE_TIMEOUT_MINS} "
                               f"minutes")

        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            handshake_socket.recv(), type=EngineHandshakeMetadata)
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for attr_name, attr_value in init_message.parallel_config.items():
                setattr(parallel_config, attr_name, attr_value)

        return init_message.addresses
660
661

    @staticmethod
    def run_engine_core(*args,
                        dp_rank: int = 0,
                        local_dp_rank: int = 0,
                        **kwargs):
        """Launch EngineCore busy loop in background process.

        Positional and remaining keyword arguments are forwarded to the
        ``EngineCoreProc`` (or ``DPEngineCoreProc``) constructor; ``kwargs``
        must contain ``vllm_config``.

        Args:
            dp_rank: Data-parallel rank. When > 0, or when the configured DP
                size is > 1, a ``DPEngineCoreProc`` is created and the rank is
                written into ``vllm_config.parallel_config``.
            local_dp_rank: Node-local data-parallel rank, written alongside
                ``dp_rank``.

        Raises:
            SystemExit: On the first SIGTERM/SIGINT.
            Exception: Any startup or busy-loop failure is logged and
                re-raised. If the engine had been constructed, an
                engine-dead notification is sent first.
        """
        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: Optional[EngineCoreProc] = None
        try:
            parallel_config: ParallelConfig = kwargs[
                "vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                set_process_title("EngineCore", f"DP{dp_rank}")
                decorate_logs()
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                set_process_title("EngineCore")
                decorate_logs()
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare raise re-raises the exception being handled without adding
            # an extra traceback entry for this line.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

718
719
720
    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Data-parallel initialization hook; intentionally does nothing.

        ``DPEngineCoreProc`` overrides this to set up its DP rank and
        stateless DP process group.
        """
        return None

721
722
723
    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Repeatedly waits for / drains client input and then performs one
        engine step. The loop never returns normally; it ends only when one
        of these calls raises (e.g. the SystemExit raised by the SIGINT /
        SIGTERM handler installed in ``run_engine_core``).
        """
        while True:
            # Block until there is work, handling queued client requests.
            self._process_input_queue()
            # Run one engine step and hand its outputs to the output queue.
            self._process_engine_step()

    def _process_input_queue(self):
        """Handle client requests; return once an engine step is needed.

        While the engines are not running, the scheduler has no requests and
        the batch queue is empty, this blocks on the input queue and handles
        each request as it arrives. It then drains any remaining queued
        requests without blocking before returning.
        """
        announced_wait = False
        while not (self.engines_running or self.scheduler.has_requests()
                   or self.batch_queue):
            # Only log the idle transition when we are really about to block.
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                announced_wait = True
            self._handle_client_request(*self.input_queue.get())

        if announced_wait:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            self._handle_client_request(*self.input_queue.get_nowait())

751
    def _process_engine_step(self) -> bool:
752
753
754
        """Called only when there are unfinished local requests."""

        # Step the engine core.
755
        outputs, model_executed = self.step_fn()
756
        # Put EngineCoreOutputs into the output queue.
757
758
        for output in (outputs.items() if outputs else ()):
            self.output_queue.put_nowait(output)
759
760
        # Post-step hook.
        self.post_step(model_executed)
761

762
763
        return model_executed

764
765
766
    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        Args:
            request_type: Kind of request being dispatched.
            request: Payload whose shape depends on ``request_type``:
                ADD -> ``(request, request_wave)``;
                ABORT -> passed unchanged to ``abort_requests``;
                UTILITY -> ``(client_idx, call_id, method_name, args)``.

        Raises:
            RuntimeError: For an ``EXECUTOR_FAILED`` request.
        """
        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except Exception as e:
                # Only ordinary errors are reported back to the caller.
                # SystemExit/KeyboardInterrupt must propagate: the engine's
                # signal handler raises SystemExit only once, so swallowing
                # it here would leave the engine running and unstoppable by
                # further SIGTERM/SIGINT.
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output)))
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error("Unrecognized input request type encountered: %s",
                         request_type)
791
792
793
794
795
796
797
798
799
800
801
802
803
804

    @staticmethod
    def _convert_msgspec_args(method, args):
        """Coerce positional args into msgspec Structs where ``method`` wants one.

        Each arg is paired positionally with a parameter of ``method``. If
        that parameter's annotation is a ``msgspec.Struct`` subclass and the
        arg is not already an instance of it, the arg is converted with
        ``msgspec.convert``; otherwise it is passed through unchanged.
        Falsy ``args`` (empty or None) are returned as-is.
        """
        if not args:
            return args
        params = signature(method).parameters.values()
        assert len(args) <= len(params)

        def _coerce(value, param):
            target = param.annotation
            wants_struct = (isclass(target)
                            and issubclass(target, msgspec.Struct)
                            and not isinstance(value, target))
            if wants_struct:
                return msgspec.convert(value, type=target)
            return value

        return tuple(_coerce(value, param)
                     for value, param in zip(args, params))
805

806
807
808
809
810
811
812
813
814
815
816
817
    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient.

        Enqueues the ``ENGINE_CORE_DEAD`` sentinel and waits up to five
        seconds for the output thread, which exits after broadcasting the
        sentinel. Logs a fatal message if the thread is still alive.
        """
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Give the output thread a bounded amount of time to flush it.
        sender = self.output_thread
        sender.join(timeout=5.0)
        if not sender.is_alive():
            return
        logger.fatal("vLLM shutdown signal from EngineCore failed "
                     "to send. Please report this issue.")

818
819
    def process_input_sockets(self, input_addresses: list[str],
                              coord_input_address: Optional[str],
                              identity: bytes, ready_event: threading.Event):
        """Input socket IO thread.

        Connects a DEALER socket to every front-end input address (and, if
        ``coord_input_address`` is given, an XSUB socket to the DP
        coordinator). Sets ``ready_event`` once all sockets are ready, then
        forever decodes incoming messages and pushes
        ``(request_type, request)`` tuples onto ``self.input_queue`` for the
        busy loop.

        Raises:
            RuntimeError: If the coordinator's first message is not
                ``b"READY"``.
        """
        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx,
                                    input_address,
                                    zmq.DEALER,
                                    identity=identity,
                                    bind=False))
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(ctx,
                                    coord_input_address,
                                    zmq.XSUB,
                                    identity=identity,
                                    bind=False))
                # Send subscription message to coordinator.
                coord_socket.send(b'\x01')

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b'')
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. The recv() must not
                # live inside an assert: asserts are stripped under
                # ``python -O``, which would skip consuming READY and let the
                # loop below try to parse it as a request type.
                ready_msg = coord_socket.recv()
                if ready_msg != b"READY":
                    raise RuntimeError(
                        "Expected READY message from DP coordinator, "
                        f"got {ready_msg!r}")
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(
                        copy=False)
                    request_type = EngineCoreRequestType(
                        bytes(type_frame.buffer))

                    # Deserialize the request data.
                    if request_type == EngineCoreRequestType.ADD:
                        request = add_request_decoder.decode(data_frames)
                        request = self.preprocess_add_request(request)
                    else:
                        request = generic_decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(self, output_paths: list[str],
                               coord_output_path: Optional[str],
                               engine_index: int):
        """Output socket IO thread.

        Takes items from ``self.output_queue`` and sends them:
        ``(client_index, outputs)`` tuples go to the PUSH socket for
        ``output_paths[client_index]``, or to the coordinator socket when
        ``client_index == -1``. The ``ENGINE_CORE_DEAD`` sentinel is sent to
        every client socket, after which the thread exits.

        Args:
            output_paths: One address per client output socket.
            coord_output_path: Coordinator address, or None if there is no
                coordinator.
            engine_index: Stamped onto every outgoing ``outputs``.
        """
        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        # Newest entries are appended on the left, oldest sit on the right.
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
                for output_path in output_paths
            ]
            coord_socket = stack.enter_context(
                make_zmq_socket(
                    ctx, coord_output_path, zmq.PUSH, bind=False,
                    linger=4000)) if coord_output_path is not None else None
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with. As on the
                # direct path below, keep at most max_reuse_bufs of them;
                # surplus buffers (and the output refs) are released.
                while pending and pending[-1][0].done:
                    reclaimed = pending.pop()[2]
                    if len(reuse_buffers) < max_reuse_bufs:
                        reuse_buffers.append(reclaimed)

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(buffers,
                                                               copy=False,
                                                               track=True)
                if not tracker.done:
                    # Extra frames presumably reference memory owned by
                    # outputs (zero-copy), so keep outputs alive until sent.
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)
943
944
945
946
947
948
949
950
951


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context.

    Adds DP "wave" tracking on top of EngineCoreProc. Engines keep stepping,
    running dummy batches when they have no local work, until a periodic
    all-reduce over the DP group reports that no rank has unfinished
    requests. At that point the wave is complete and the wave counter
    advances.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: Optional[str] = None,
    ):
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        # Index of the current DP wave (advanced in run_busy_loop/add_request).
        self.current_wave = 0
        # Last scheduler request counts published to the coordinator.
        self.last_counts = (0, 0)

        # Initialize the engine. The DP rank is passed as the base class's
        # final positional argument (presumably its engine index).
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(vllm_config, local_client, handshake_address,
                         executor_class, log_stats, client_handshake_address,
                         dp_rank)

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Record this engine's DP rank and create the stateless DP group.

        If a KV transfer config is present, its ``engine_id`` is suffixed
        with ``_dp{local_dp_rank}`` so that each DP rank gets a distinct id.
        """
        # Configure GPUs and stateless process group for data parallel.
        parallel_config = vllm_config.parallel_config
        dp_rank = parallel_config.data_parallel_rank
        dp_size = parallel_config.data_parallel_size
        local_dp_rank = parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug("Setting kv_transfer_config.engine_id to %s",
                         vllm_config.kv_transfer_config.engine_id)

        self.dp_rank = dp_rank
        self.dp_group = parallel_config.stateless_init_dp_group()

    def _destroy_dp_group(self) -> None:
        """Destroy the stateless DP process group, if one is held.

        The reference is cleared before the group is destroyed, so repeated
        calls (e.g. reinitialize_distributed() followed by shutdown(), or a
        second shutdown() at process exit) never tear down the same group
        twice.
        """
        if dp_group := getattr(self, "dp_group", None):
            self.dp_group = None
            stateless_destroy_torch_distributed_process_group(dp_group)

    def shutdown(self):
        """Shut down the engine core, then destroy the DP process group."""
        super().shutdown()
        self._destroy_dp_group()

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request, reconciling its wave with this engine's wave.

        With a coordinator, a request tagged with a newer wave advances
        ``current_wave``. A request from an older wave that arrives while the
        engines are paused asks the front-end (through the coordinator,
        client index -1) to start the current wave.
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave)))

        super().add_request(request, request_wave)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Handle START_DP_WAVE here; delegate every other type to the base.

        A START_DP_WAVE payload is ``(new_wave, exclude_eng_index)``. It is
        ignored when it excludes this engine or names an older wave.
        Otherwise the current wave is updated and, if the engine loop is
        paused, it is started.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                    new_wave >= self.current_wave):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.",
                                 new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Send scheduler request counts to the coordinator when they change.

        Does nothing unless ``publish_dp_lb_stats`` is enabled.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(*counts,
                                   step_counter=self.step_counter,
                                   current_wave=self.current_wave)
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs)

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug("Wave %d finished, pausing engine loop.",
                                 self.current_wave)
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (client_index,
                         EngineCoreOutputs(wave_complete=self.current_wave)))
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank may still have unfinished requests.

        Only every 32nd call synchronizes across the DP group; all other
        calls optimistically return True.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)

    def reinitialize_distributed(
            self, reconfig_request: ReconfigureDistributedRequest) -> None:
        """Reconfigure this engine for a new data parallel size.

        Tears down the DP group and shuts the engine core down. It then
        applies the new DP size, rank and master address/port from
        ``reconfig_request`` to the parallel config, and recreates the DP
        group unless this rank is being shut down. Finally it reinitializes
        the executor. When scaling up, it also shares the KV-cache memory
        budget with the newly joined engines and warms up the model.

        Note: ``reconfig_request.new_data_parallel_master_port`` is
        overwritten with the parallel config's master port before the
        request is forwarded to the executor.
        """
        # Destroy the DP group before shutting the engine down. Because
        # _destroy_dp_group() clears self.dp_group, shutdown() does not
        # destroy the same group a second time.
        self._destroy_dp_group()
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = \
            reconfig_request.new_data_parallel_size
        if reconfig_request.new_data_parallel_rank != \
                ReconfigureRankType.KEEP_CURRENT_RANK:
            parallel_config.data_parallel_rank = \
                reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert reconfig_request.new_data_parallel_rank_local == \
            ReconfigureRankType.KEEP_CURRENT_RANK
        parallel_config.data_parallel_master_ip = \
            reconfig_request.new_data_parallel_master_ip
        parallel_config.data_parallel_master_port = \
            reconfig_request.new_data_parallel_master_port
        if reconfig_request.new_data_parallel_rank != \
                ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        reconfig_request.new_data_parallel_master_port = \
            parallel_config.data_parallel_master_port

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache)
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if reconfig_request.new_data_parallel_rank == \
                ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info("Distributed environment reinitialized for DP rank %s",
                        self.dp_rank)

Rui Qiao's avatar
Rui Qiao committed
1133
1134
1135
1136
1137
1138
1139
1140
1141

class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        """Create the engine core for one DP rank inside a Ray actor.

        Args:
            vllm_config: Engine config; its parallel config's DP rank and
                local DP rank are overwritten with ``dp_rank`` and
                ``local_dp_rank``.
            local_client: Forwarded to ``DPEngineCoreProc``.
            addresses: ZMQ addresses, already known at actor creation. They
                are returned by ``_perform_handshakes`` in place of a real
                handshake.
            executor_class: Forwarded to ``DPEngineCoreProc``.
            log_stats: Forwarded to ``DPEngineCoreProc``.
            dp_rank: Data parallel rank of this engine.
            local_dp_rank: Data parallel rank on this node; selects which
                device indices this actor exposes.
        """
        # Must be set before super().__init__(): _perform_handshakes yields
        # it (presumably invoked during base-class initialization).
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = \
            local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

        # The handshake address is unused: _perform_handshakes is overridden
        # below to yield self.addresses without contacting anyone.
        super().__init__(vllm_config, local_client, "", executor_class,
                         log_stats)

    def _set_visible_devices(self, vllm_config: VllmConfig,
                             local_dp_rank: int):
        """Restrict visible devices to this DP rank's range (skipped on XPU)."""
        from vllm.platforms import current_platform
        if current_platform.is_xpu():
            # No device-visibility env var is set on XPU here.
            return
        self._set_cuda_visible_devices(
            vllm_config, local_dp_rank,
            current_platform.device_control_env_var)

    def _set_cuda_visible_devices(self, vllm_config: VllmConfig,
                                  local_dp_rank: int,
                                  device_control_env_var: str):
        """Set ``device_control_env_var`` to this DP rank's device indices.

        Raises:
            RuntimeError: If the env var's current value does not cover the
                range ``[local_dp_rank * world_size,
                (local_dp_rank + 1) * world_size)``.
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(device_control_env_var, local_dp_rank,
                                       world_size)
        except IndexError as e:
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f"base value: \"{os.getenv(device_control_env_var)}\"") from e
        os.environ[device_control_env_var] = value

    @contextmanager
    def _perform_handshakes(self, handshake_address: str, identity: bytes,
                            local_client: bool, vllm_config: VllmConfig,
                            client_handshake_address: Optional[str]):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        SystemExit and any other exception are logged and re-raised. The
        engine core is shut down in every case.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()