core.py 54.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import gc
4
import os
5
import queue
6
import signal
7
8
import threading
import time
9
from collections import deque
10
from collections.abc import Callable, Generator
11
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
12
from contextlib import ExitStack, contextmanager
13
from inspect import isclass, signature
14
from logging import DEBUG
15
from typing import Any, TypeVar, cast
16

17
import msgspec
18
19
import zmq

20
21
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
22
from vllm.envs import enable_envs_cache
23
from vllm.logger import init_logger
24
from vllm.logging_utils.dump_input import dump_engine_exception
25
from vllm.lora.request import LoRARequest
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.multimodal.cache import engine_receiver_cache_from_config
28
from vllm.tasks import POOLING_TASKS, SupportedTask
29
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
30
from vllm.utils.gc_utils import maybe_attach_gc_debug_callback
31
from vllm.utils.hashing import get_hash_fn_by_name
32
from vllm.utils.network_utils import make_zmq_socket
33
from vllm.utils.system_utils import decorate_logs, set_process_title
34
35
36
37
38
39
40
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
41
from vllm.v1.core.sched.interface import SchedulerInterface
42
from vllm.v1.core.sched.output import SchedulerOutput
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from vllm.v1.engine import (
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
57
from vllm.v1.executor import Executor
58
from vllm.v1.kv_cache_interface import KVCacheConfig
59
from vllm.v1.metrics.stats import SchedulerStats
60
from vllm.v1.outputs import ModelRunnerOutput
61
from vllm.v1.request import Request, RequestStatus
62
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
63
from vllm.v1.structured_output import StructuredOutputManager
64
65
66
67
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Polling timeout, expressed in seconds (see the `_S` suffix).
POLLING_TIMEOUT_S: float = 2.5
# How long startup_handshake waits for the front-end's init message, in minutes.
HANDSHAKE_TIMEOUT_MINS: int = 5

# Generic return type used by collective_rpc.
_R = TypeVar("_R")

class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler and the structured-output manager
    and advances them one step at a time through ``step_fn``, which is
    ``step`` or, when the executor allows more than one concurrent batch,
    ``step_with_batch_queue``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
    ):
        """Build the executor, KV caches and scheduler for this engine.

        Args:
            vllm_config: Engine configuration. ``cache_config.num_gpu_blocks``
                and ``cache_config.num_cpu_blocks`` are written in place, and
                ``scheduler_config.chunked_prefill_enabled`` is cleared when the
                model has no KV-cache groups.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Stored on the engine and forwarded to the scheduler.
            executor_fail_callback: If given, registered on the executor via
                ``register_failure_callback``.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        if vllm_config.parallel_config.data_parallel_rank == 0:
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        Scheduler = vllm_config.scheduler_config.get_scheduler_cls()

        if len(kv_cache_config.kv_cache_groups) == 0:
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.chunked_prefill_enabled = False

        # The scheduler works in blocks that span all decode-context-parallel
        # ranks, hence the multiplication.
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None
        if self.scheduler.connector is not None:  # type: ignore
            self.model_executor.init_kv_output_aggregator(self.scheduler.connector)  # type: ignore

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = engine_receiver_cache_from_config(
            vllm_config, mm_registry
        )

        # If a KV connector is initialized for scheduler, we want to collect
        # handshake metadata from all workers so the connector in the scheduler
        # will have the full context
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            # Collect and store KV connector xfer metadata from workers
            # (after KV cache registration)
            xfer_handshake_metadata = (
                self.model_executor.get_kv_connector_handshake_metadata()
            )

            if xfer_handshake_metadata:
                # xfer_handshake_metadata is list of dicts from workers
                # Each dict already has structure {tp_rank: metadata}
                # Merge all worker dicts into a single dict
                content: dict[int, Any] = {}
                for worker_dict in xfer_handshake_metadata:
                    if worker_dict is not None:
                        content.update(worker_dict)
                kv_connector.set_xfer_handshake_metadata(content)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        # Request block hashes are only computed when prefix caching is enabled
        # or a KV connector is configured.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if (
            self.vllm_config.cache_config.enable_prefix_caching
            or kv_connector is not None
        ):
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )

    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Size the KV cache, then initialize (and warm up) the executor with it.

        Also records ``self.available_gpu_memory_for_kv_cache`` when the model
        has a KV cache.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``;
            ``num_cpu_blocks`` is always 0 here.
        """
        start = time.time()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # NOTE(review): dp_group is not assigned by EngineCore itself;
                # presumably a subclass sets it before this runs — verify.
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                # No local profiling on this path; the size comes from
                # sync_kv_cache_memory_size over the DP group.
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )
        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.time() - start
        logger.info_once(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
            scope="local",
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the model executor."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case (not used by this base implementation).

        Raises:
            TypeError: if ``request.request_id`` is not a ``str``.
            ValueError: if the request carries pooling params whose task is
                not among the supported pooling tasks.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        if request.kv_transfer_params is not None and (
            not self.scheduler.get_kv_connector()
        ):
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Dump engine state if the wrapped block raises, then re-raise.

        Wraps waits on model-execution futures so that a failure is reported
        together with ``scheduler_output`` and the scheduler's current stats.
        """
        try:
            yield
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare raise propagates the active exception with its original
            # traceback.
            raise

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output, non_block=True)
        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
        with self.log_error_detail(scheduler_output):
            model_output = future.result()
            # A None result means sampling is a separate call.
            if model_output is None:
                model_output = self.model_executor.sample_tokens(grammar_output)

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

    def post_step(self, model_executed: bool) -> None:
        """Hand draft token ids to the scheduler after an executed step.

        Only acts when speculative decoding is configured and the model ran.
        """
        if self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, filling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the batch queue is finished.
        3. Update the scheduler from the output.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        deferred_scheduler_output = None
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            exec_future = self.model_executor.execute_model(
                scheduler_output, non_block=True
            )
            model_executed = scheduler_output.total_num_scheduled_tokens > 0

            if scheduler_output.pending_structured_output_tokens:
                # We need to defer sampling until we have processed the model output
                # from the prior step.
                deferred_scheduler_output = scheduler_output
                # Block-wait for execute to return (continues running async on the GPU).
                with self.log_error_detail(scheduler_output):
                    exec_result = exec_future.result()
                    assert exec_result is None
            else:
                # We aren't waiting for any tokens, get any grammar output immediately.
                grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
                # Block-wait for execute to return (continues running async on the GPU).
                with self.log_error_detail(scheduler_output):
                    exec_result = exec_future.result()

                if exec_result is None:
                    # Call sample tokens.
                    future = self.model_executor.sample_tokens(
                        grammar_output, non_block=True
                    )
                else:
                    # No sampling required (e.g. all requests finished).
                    future = cast(Future[ModelRunnerOutput], exec_future)
                # Add this step's future to the queue.
                batch_queue.appendleft((future, scheduler_output))
                if (
                    model_executed
                    and len(batch_queue) < self.batch_queue_size
                    and not batch_queue[-1][0].done()
                ):
                    # Don't block on next worker response unless the queue is full
                    # or there are no more requests to schedule.
                    return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output = batch_queue.pop()
        with self.log_error_detail(scheduler_output):
            model_output = future.result()

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        # NOTE(nick): We can either handle the deferred tasks here or save
        # in a field and do it immediately once step_with_batch_queue is
        # re-called. The latter slightly favors TTFT over TPOT/throughput.
        if deferred_scheduler_output:
            # We now have the tokens needed to compute the bitmask for the
            # deferred request. Get the bitmask and call sample tokens.
            grammar_output = self.scheduler.get_grammar_bitmask(
                deferred_scheduler_output
            )
            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
            batch_queue.appendleft((future, deferred_scheduler_output))

        return engine_core_outputs, model_executed

    def shutdown(self):
        """Clear the structured-output backend, then shut down executor and scheduler.

        The executor and scheduler are shut down even if an earlier step
        raises; the first exception still propagates. Components that were
        never created (``__init__`` did not finish) are skipped.
        """
        structured_output_manager = getattr(self, "structured_output_manager", None)
        model_executor = getattr(self, "model_executor", None)
        scheduler = getattr(self, "scheduler", None)
        try:
            if structured_output_manager is not None:
                structured_output_manager.clear_backend()
        finally:
            try:
                if model_executor:
                    model_executor.shutdown()
            finally:
                if scheduler:
                    scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling on the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Clear the engine-side and executor-side multi-modal caches."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(self):
        """Reset the scheduler's prefix cache."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the executor to sleep at the given level."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: list[str] | None = None):
        """Wake the executor up, optionally only for the given tags."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` state."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run a dummy batch on the executor."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA registration to the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA removal to the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Forward a LoRA pin request to the executor."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Ask the executor to save sharded model state under ``path``."""
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Run ``method`` on the executor's workers and return their results."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The constructed ``Request`` and ``request.current_wave``.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

534
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
535

536
537
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        engine_index: int = 0,
    ):
        """Set up the engine core plus its ZMQ IO threads.

        All initialisation happens while the startup handshake(s) are open:
        data-parallel setup, the base ``EngineCore`` construction, and the
        start of the input/output socket threads. The handshake is only
        completed once the input thread signals readiness.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()

        def executor_fail_callback():
            # Report an executor failure by enqueueing an EXECUTOR_FAILED item.
            return self.input_queue.put_nowait(
                (EngineCoreRequestType.EXECUTOR_FAILED, b"")
            )

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        handshakes = self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        )
        with handshakes as addresses:
            self.client_count = len(addresses.outputs)

            # Data-parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            # Request-queue stats are published to the coordinator only in the
            # "internal" and "hybrid" LB modes, i.e. not with external LB.
            external_lb = vllm_config.parallel_config.data_parallel_external_lb
            self.publish_dp_lb_stats = self.has_coordinator and not external_lb

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config, executor_class, log_stats, executor_fail_callback
            )

            # IO runs on background threads so that ZMQ socket work (which
            # releases the GIL) and some (de)serialization overlap with the
            # model forward pass. The threads move data between sockets and
            # queues; core_busy_loop only touches the queues.
            coordinator_ready = threading.Event()
            reader = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    identity,
                    coordinator_ready,
                ),
                daemon=True,
            )
            reader.start()

            writer_args = (
                addresses.outputs,
                addresses.coordinator_output,
                self.engine_index,
            )
            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=writer_args,
                daemon=True,
            )
            self.output_thread.start()

            # Keep the handshake open until the input thread reports that the
            # DP coordinator's READY message arrived.
            while True:
                if coordinator_ready.wait(timeout=10):
                    break
                if not reader.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

        # Freeze the startup heap so the GC ignores it; this shortens
        # oldest-generation collection pauses.
        gc.collect()
        gc.freeze()

        # Optional GC debugging hook, attached after the freeze.
        maybe_attach_gc_debug_callback()

        # From here on, environment variables are assumed fixed and cached.
        enable_envs_cache()
    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Run the startup handshake(s) and yield the resulting ZMQ addresses.

        Who the engine talks to depends on the deployment:

        - DP=1 or offline mode: the colocated front-end process.
        - DP>1 with internal load-balancing: the shared front-end process,
          which may live on another node.
        - DP>1 with external or hybrid load-balancing: two handshakes,
            - one with the rank 0 front-end process, which provides the
              DP Coordinator ZMQ addresses and the DP process group address;
            - one with the colocated front-end process, which provides the
              client input/output socket addresses.
          The rank 0 and colocated engines themselves skip the second one.

        "Front-end" process here means either the process holding the engine
        core client (the API server process when the API server is not scaled
        out) or the launcher process running run_multi_api_server() in
        serve.py.

        Once the caller's block completes, ``vllm_config.__post_init__()`` is
        re-run because the handshake may have changed the config.
        """
        input_ctx = zmq.Context()
        headless = not local_client

        primary_handshake = self._perform_handshake(
            input_ctx,
            handshake_address,
            identity,
            local_client and client_handshake_address is None,
            headless,
            vllm_config,
            vllm_config.parallel_config,
        )

        # Second handshake with the colocated front-end, only needed when a
        # separate client handshake address was supplied.
        client_handshake = None
        if client_handshake_address is not None:
            assert local_client
            client_handshake = self._perform_handshake(
                input_ctx, client_handshake_address, identity, True, False, vllm_config
            )

        with ExitStack() as stack:
            addresses = stack.enter_context(primary_handshake)
            if client_handshake is not None:
                client_addresses = stack.enter_context(client_handshake)
                # Client IO sockets come from the colocated front-end.
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
            yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()
    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run one registration/READY exchange with a front-end process.

        Connects a DEALER socket (with ``identity``) to ``handshake_address``,
        registers via ``startup_handshake`` and yields the addresses it
        returns. If the caller's block finishes without raising, a READY
        message is sent on the same socket before it is closed.
        """
        with make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        ) as sock:
            # Register engine with front-end.
            yield self.startup_handshake(
                sock, local_client, headless, parallel_config_to_update
            )

            # Everything below runs after the caller's block. num_gpu_blocks is
            # read only now because, when used from EngineCoreProc.__init__,
            # EngineCore.__init__ sets it inside that block.
            parallel_config = vllm_config.parallel_config
            ready_msg = dict(
                status="READY",
                local=local_client,
                headless=headless,
                num_gpu_blocks=vllm_config.cache_config.num_gpu_blocks,
                # The coordinator stats update address is returned for the
                # external LB case so the colocated front-end can use it (the
                # coordinator only runs with rank 0).
                dp_stats_address=self.frontend_stats_publish_address,
            )
            if parallel_config.data_parallel_size > 1:
                # Config hash used for DP configuration validation.
                ready_msg["parallel_config_hash"] = parallel_config.compute_hash()

            sock.send(msgspec.msgpack.encode(ready_msg))
    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Exchange the HELLO / init messages with the front-end.

        Sends a HELLO registration message. It then waits up to
        ``HANDSHAKE_TIMEOUT_MINS`` minutes for the front-end's init message.
        When ``parallel_config`` is given, every entry of the init message's
        ``parallel_config`` mapping is applied to it with ``setattr``.

        Returns:
            The addresses carried by the init message.

        Raises:
            RuntimeError: if no init message arrives before the timeout.
        """
        hello = {"status": "HELLO", "local": local_client, "headless": headless}
        handshake_socket.send(msgspec.msgpack.encode(hello))

        logger.debug("Waiting for init message from front-end.")
        timeout_ms = HANDSHAKE_TIMEOUT_MINS * 60_000
        if not handshake_socket.poll(timeout=timeout_ms):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                f"minutes"
            )

        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            handshake_socket.recv(), type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for attr_name, attr_value in init_message.parallel_config.items():
                setattr(parallel_config, attr_name, attr_value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process.

        Builds a ``DPEngineCoreProc`` when data parallelism is in use
        (``data_parallel_size > 1`` or a non-zero ``dp_rank``). Otherwise it
        builds a plain ``EngineCoreProc``. It then runs the engine's busy loop.
        The engine, if constructed, is always shut down on exit.

        Args:
            *args: Positional arguments forwarded to the engine constructor.
            dp_rank: Data-parallel rank assigned to this engine process.
            local_dp_rank: Local data-parallel rank assigned to this process.
            **kwargs: Keyword arguments forwarded to the engine constructor;
                must contain ``vllm_config``.

        Raises:
            SystemExit: on the first SIGTERM/SIGINT.
            Exception: any startup or runtime error is logged and re-raised.
        """
        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: EngineCoreProc | None = None
        try:
            parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                set_process_title("EngineCore", f"DP{dp_rank}")
                decorate_logs()
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                set_process_title("EngineCore")
                decorate_logs()
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare ``raise`` re-raises the active exception with its original
            # traceback, without appending a duplicate entry for this frame.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Data-parallel setup hook; intentionally a no-op for this class.

        ``DPEngineCoreProc`` overrides it to configure its DP rank and group.
        """
        return None

    def run_busy_loop(self):
        """Run the engine core loop indefinitely.

        The loop has no exit condition of its own. It ends only when an
        exception escapes, e.g. the ``SystemExit`` raised by the SIGTERM /
        SIGINT handler installed in ``run_engine_core``.
        """
        while True:
            # Wait for (and handle) client requests until there is work, then
            # perform one engine step and enqueue its outputs.
            self._process_input_queue()
            self._process_engine_step()

    def _process_input_queue(self):
        """Handle pending client requests; return once an engine step is due.

        While ``engines_running`` is false, the scheduler has no requests and
        ``batch_queue`` is empty, this blocks on ``input_queue`` and handles
        each request as it arrives. It then handles every request already
        queued, without blocking.
        """
        waited = False
        while not (
            self.engines_running or self.scheduler.has_requests() or self.batch_queue
        ):
            # Only log the idle transition when we are about to block.
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                waited = True
            self._handle_client_request(*self.input_queue.get())

        if waited:
            logger.debug("EngineCore loop active.")

        # Drain anything else that is already queued, without blocking.
        while not self.input_queue.empty():
            self._handle_client_request(*self.input_queue.get_nowait())

    def _process_engine_step(self) -> bool:
        """Perform one engine step and enqueue its outputs for sending.

        Every ``(client_index, outputs)`` pair returned by ``step_fn`` is put
        on ``output_queue``. ``post_step`` is then invoked.

        Returns:
            Whether the model was executed during this step.
        """
        outputs, model_executed = self.step_fn()
        if outputs:
            for client_output in outputs.items():
                self.output_queue.put_nowait(client_output)

        self.post_step(model_executed)
        return model_executed

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch a single request received from a client.

        Args:
            request_type: Which kind of request this is.
            request: Decoded payload. For ADD it is ``(request, wave)``. For
                ABORT it is passed to ``abort_requests``. For UTILITY it is
                ``(client_idx, call_id, method_name, args)``.

        Raises:
            RuntimeError: on an EXECUTOR_FAILED request.
        """
        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except Exception as e:
                # Report ordinary failures back to the caller. BaseException
                # is deliberately NOT caught: the SIGTERM/SIGINT handler in
                # run_engine_core raises SystemExit only once. Swallowing it
                # here would leave the engine running after termination was
                # requested.
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (
                    f"Call to {method_name} method failed: {str(e)}"
                )
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output))
            )
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
918
        arg type, try converting to msgspec object."""
919
920
921
922
923
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
924
925
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
926
            and issubclass(p.annotation, msgspec.Struct)
927
928
929
930
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, arg_types)
        )
931

932
933
934
935
936
937
938
939
940
    def _send_engine_dead(self):
        """Tell the EngineCoreClient that this engine core has died.

        Enqueues the ``ENGINE_CORE_DEAD`` sentinel and waits up to 5 seconds
        for ``output_thread`` to finish. The output loop breaks after
        forwarding that sentinel; presumably ``output_thread`` runs
        ``process_output_sockets``. If the thread is still alive afterwards,
        a fatal message is logged.
        """
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        output_thread = self.output_thread
        output_thread.join(timeout=5.0)
        if not output_thread.is_alive():
            return
        logger.fatal(
            "vLLM shutdown signal from EngineCore failed "
            "to send. Please report this issue."
        )

    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Connects a DEALER socket to each address in ``input_addresses``. When
        ``coord_input_address`` is given, it also connects an XSUB socket to
        the coordinator. ``ready_event`` is set once every socket is
        registered with the poller. The thread then loops forever, decoding
        incoming ``(request_type, payload)`` messages and putting them on
        ``self.input_queue`` for the busy loop.

        Raises:
            RuntimeError: if the coordinator's first message is not b"READY".
        """
        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. This is an explicit
                # check rather than an ``assert``. The recv() must consume this
                # message before polling starts, and an assert (together with
                # the recv inside it) is stripped when Python runs with -O.
                coord_ready_msg = coord_socket.recv()
                if coord_ready_msg != b"READY":
                    raise RuntimeError(
                        "Expected READY message from DP coordinator, got "
                        f"{coord_ready_msg!r}"
                    )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event

            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    if request_type == EngineCoreRequestType.ADD:
                        request = add_request_decoder.decode(data_frames)
                        request = self.preprocess_add_request(request)
                    else:
                        request = generic_decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread.

        Takes items off ``self.output_queue`` and sends each one to the client
        socket selected by its client index. Index -1 goes to the coordinator
        socket instead. After forwarding the ``ENGINE_CORE_DEAD`` sentinel to
        every client socket, the thread leaves its loop.
        """
        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()

        # Pool of send buffers available for reuse.
        free_buffers: list[bytearray] = []
        # Sends that zmq may still be reading from, newest on the left. Each
        # entry keeps its buffer (and, when needed, the outputs object, which
        # may own tensor/np-array memory extracted for zero-copy send) alive.
        in_flight = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # Linger is set so the ENGINE_CORE_DEAD message is still delivered
        # before the sockets are closed.
        with ExitStack() as stack, zmq.Context() as ctx:
            client_sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                )
                for output_path in output_paths
            ]
            if coord_output_path is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
            max_free_buffers = len(client_sockets) + 1

            while True:
                item = self.output_queue.get()
                if item == EngineCoreProc.ENGINE_CORE_DEAD:
                    for client_socket in client_sockets:
                        client_socket.send(item)
                    break

                assert not isinstance(item, bytes)
                client_index, outputs = item
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Coordinator messages are small; skip buffer reuse.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Return buffers of completed sends (oldest on the right) to
                # the free pool.
                while in_flight and in_flight[-1][0].done:
                    free_buffers.append(in_flight.pop()[2])

                buffer = free_buffers.pop() if free_buffers else bytearray()
                frames = encoder.encode_into(outputs, buffer)
                tracker = client_sockets[client_index].send_multipart(
                    frames, copy=False, track=True
                )

                if not tracker.done:
                    # Extra frames (presumably zero-copy array buffers backed
                    # by ``outputs``) require ``outputs`` to stay alive too.
                    keep_alive = outputs if len(frames) > 1 else None
                    in_flight.appendleft((tracker, keep_alive, buffer))
                elif len(free_buffers) < max_free_buffers:
                    # Cap the size of the reuse pool.
                    free_buffers.append(buffer)


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        self.current_wave = 0
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Record this engine's DP rank and create its stateless DP group.

        When a KV-transfer config is present, its ``engine_id`` is suffixed
        with ``_dp<local_dp_rank>`` so it is unique per DP rank.
        """
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert local_dp_rank is not None
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug(
                "Setting kv_transfer_config.engine_id to %s",
                vllm_config.kv_transfer_config.engine_id,
            )

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine, then destroy the DP process group at most once.

        ``dp_group`` is cleared before it is destroyed. Repeated calls, such as
        those made by ``reinitialize_distributed``, therefore never destroy
        the same group twice.
        """
        super().shutdown()
        dp_group = getattr(self, "dp_group", None)
        if dp_group is not None:
            self.dp_group = None
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request, reconciling its DP wave with ``current_wave``.

        With a coordinator, a request from a newer wave advances
        ``current_wave``. A request for an older wave that arrives while the
        engines are not running causes a ``start_wave`` message to be queued
        for the coordinator (client index -1).
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle START_DP_WAVE here; delegate other requests to the base class."""
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            # The engine named by exclude_eng_index ignores this notification.
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Queue scheduler request counts for the coordinator when they change."""
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank still has unfinished requests.

        The cross-rank check runs only on every 32nd call. Other calls return
        True without communicating.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild data-parallel state for a DP resize described by the request."""
        # Tear down the old DP group first (same order as before). Clear the
        # attribute so that shutdown() does not destroy the group a second time.
        old_dp_group, self.dp_group = self.dp_group, None
        stateless_destroy_torch_distributed_process_group(old_dp_group)
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        # NOTE(review): -1 presumably equals
        # ReconfigureRankType.KEEP_CURRENT_RANK; confirm.
        if reconfig_request.new_data_parallel_rank != -1:
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        # NOTE(review): -2 presumably equals
        # ReconfigureRankType.SHUTDOWN_CURRENT_RANK; a rank being shut down
        # does not create a new DP group.
        if reconfig_request.new_data_parallel_rank != -2:
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        # Copy the master port (possibly updated by group init) back into the
        # request before passing it on to the executor.
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )

class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

        super().__init__(vllm_config, local_client, "", executor_class, log_stats)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict device visibility for this DP rank (skipped on XPU)."""
        from vllm.platforms import current_platform

        if not current_platform.is_xpu():
            self._set_cuda_visible_devices(
                vllm_config, local_dp_rank, current_platform.device_control_env_var
            )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Set ``device_control_env_var`` to the device indices for this rank.

        The indices come from ``get_device_indices``. The error message below
        reports the expected index range
        ``[local_dp_rank * world_size, (local_dp_rank + 1) * world_size)``.

        Raises:
            RuntimeError: if ``get_device_indices`` raises ``IndexError``.
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            # RuntimeError subclasses Exception, so existing handlers still
            # catch it.
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e
        os.environ[device_control_env_var] = value

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()