core.py 55.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
import queue
5
import signal
6
7
import threading
import time
8
from collections import deque
9
from collections.abc import Callable, Generator
10
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
11
from contextlib import ExitStack, contextmanager
12
from inspect import isclass, signature
13
from logging import DEBUG
14
from typing import Any, TypeVar, cast
15

16
import msgspec
17
18
import zmq

19
20
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
21
from vllm.envs import enable_envs_cache
22
from vllm.logger import init_logger
23
from vllm.logging_utils.dump_input import dump_engine_exception
24
from vllm.lora.request import LoRARequest
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.multimodal.cache import engine_receiver_cache_from_config
27
from vllm.tasks import POOLING_TASKS, SupportedTask
28
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
29
30
31
32
from vllm.utils.gc_utils import (
    freeze_gc_heap,
    maybe_attach_gc_debug_callback,
)
33
from vllm.utils.hashing import get_hash_fn_by_name
34
from vllm.utils.network_utils import make_zmq_socket
35
from vllm.utils.system_utils import decorate_logs, set_process_title
36
37
38
39
40
41
42
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
43
from vllm.v1.core.sched.interface import SchedulerInterface
44
from vllm.v1.core.sched.output import SchedulerOutput
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from vllm.v1.engine import (
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
59
from vllm.v1.executor import Executor
60
from vllm.v1.kv_cache_interface import KVCacheConfig
61
from vllm.v1.metrics.stats import SchedulerStats
62
from vllm.v1.outputs import ModelRunnerOutput
63
from vllm.v1.request import Request, RequestStatus
64
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
65
from vllm.v1.structured_output import StructuredOutputManager
66
67
68
69
from vllm.version import __version__ as VLLM_VERSION

_R = TypeVar("_R")  # Return type for collective_rpc

# NOTE(review): neither timeout below is referenced in the visible part of
# this file; the units are taken from the name suffixes (_S = seconds,
# _MINS = minutes) — confirm against their users.
POLLING_TIMEOUT_S: float = 2.5
HANDSHAKE_TIMEOUT_MINS: int = 5

logger = init_logger(__name__)
74

75
76
77
78

class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler and the structured-output manager.
    One engine iteration is driven through ``self.step_fn``, which is bound at
    construction time to :meth:`step` (no batch queue) or to
    :meth:`step_with_batch_queue` (executor supports more than one concurrent
    batch).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
    ):
        """Create the executor, size and allocate the KV cache, build the scheduler.

        Args:
            vllm_config: Engine configuration. It is mutated in place: the
                computed GPU/CPU block counts are written into
                ``cache_config``, and chunked prefill is switched off on
                ``scheduler_config`` when the model has no KV cache groups.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Forwarded to the scheduler.
            executor_fail_callback: Optional callback handed to the executor's
                ``register_failure_callback``.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        if vllm_config.parallel_config.data_parallel_rank == 0:
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        # Placeholder; _initialize_kv_caches() stores the real value when the
        # model has a KV cache.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        Scheduler = vllm_config.scheduler_config.get_scheduler_cls()

        if len(kv_cache_config.kv_cache_groups) == 0:
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.enable_chunked_prefill = False

        # The scheduler (and the request block hasher below) work in blocks
        # that are decode_context_parallel_size times the cache block size.
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None
        # Hand the scheduler-side KV connector to the executor so it can
        # aggregate KV-transfer outputs from workers.
        if self.scheduler.connector is not None:  # type: ignore
            self.model_executor.init_kv_output_aggregator(self.scheduler.connector)  # type: ignore

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = engine_receiver_cache_from_config(
            vllm_config, mm_registry
        )

        # If a KV connector is initialized for scheduler, we want to collect
        # handshake metadata from all workers so the connector in the scheduler
        # will have the full context
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            # Collect and store KV connector xfer metadata from workers
            # (after KV cache registration)
            xfer_handshake_metadata = (
                self.model_executor.get_kv_connector_handshake_metadata()
            )

            if xfer_handshake_metadata:
                # xfer_handshake_metadata is list of dicts from workers
                # Each dict already has structure {tp_rank: metadata}
                # Merge all worker dicts into a single dict
                content: dict[int, Any] = {}
                for worker_dict in xfer_handshake_metadata:
                    if worker_dict is not None:
                        content.update(worker_dict)
                kv_connector.set_xfer_handshake_metadata(content)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        self.ec_producer = (
            vllm_config.ec_transfer_config is not None
            and vllm_config.ec_transfer_config.is_ec_producer
        )

        # Block hashes are needed for prefix caching and by KV connectors.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )
        self.async_scheduling = vllm_config.scheduler_config.async_scheduling

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()

    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Size the KV cache, allocate it on the workers and warm up the model.

        Also records ``self.available_gpu_memory_for_kv_cache`` when the model
        has any KV cache spec.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``.
            ``num_cpu_blocks`` is always 0 in this implementation.
        """
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by wall-clock adjustments.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Elastic EP scale-up launch: instead of profiling, obtain the
                # KV cache memory size through the DP group (presumably from
                # the already-running ranks — confirm in
                # ParallelConfig.sync_kv_cache_memory_size).
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )
        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info_once(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
            scope="local",
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks reported by the model executor."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Validate a request and add it to the scheduler.

        Args:
            request: The preprocessed request.
            request_wave: Which wave of requests this one is expected to
                belong to in the DP case. Unused by this base implementation.

        Raises:
            TypeError: If ``request.request_id`` is not a ``str``.
            ValueError: If the request carries pooling params whose task is
                not one of the supported pooling tasks.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        # NOTE(review): only a warning is emitted; the request's
        # kv_transfer_params are left as-is here.
        if request.kv_transfer_params is not None and (
            not self.scheduler.get_kv_connector()
        ):
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Dump engine state for ``scheduler_output`` if the wrapped code raises.

        The original exception is always re-raised with its traceback intact.
        """
        try:
            yield
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: dump_engine_exception is expected to be exception-free,
            # so it does not mask the error being propagated.
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare raise re-raises the active exception without adding a
            # re-raise frame to its traceback.
            raise

    def _log_err_callback(self, scheduler_output: SchedulerOutput):
        """Log error details of a future that's not expected to return a result."""

        def callback(f, sched_output=scheduler_output):
            with self.log_error_detail(sched_output):
                result = f.result()
                assert result is None

        return callback

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns:
            The outputs produced by ``scheduler.update_from_output`` and a flag
            that is True when the scheduled batch contained any tokens (i.e.
            the model was executed). Returns ``({}, False)`` when the
            scheduler has no requests.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output, non_block=True)
        # Computed while the model runs; only used if sampling is separate.
        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
        with self.log_error_detail(scheduler_output):
            model_output = future.result()
            if model_output is None:
                # execute_model returned nothing, so sampling is a separate call.
                model_output = self.model_executor.sample_tokens(grammar_output)

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

    def post_step(self, model_executed: bool) -> None:
        """Feed draft token ids back to the scheduler after a step.

        Only applies when speculative decoding is on, async scheduling is off,
        and the model actually executed in the step.
        """
        # When using async scheduling we can't get draft token ids in advance,
        # so we update draft token ids in the worker process and don't
        # need to update draft token ids here.
        if not self.async_scheduling and self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        # Set when this step's sampling must wait for the previous step's
        # output (pending structured-output tokens).
        deferred_scheduler_output: SchedulerOutput | None = None
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            exec_future = self.model_executor.execute_model(
                scheduler_output, non_block=True
            )
            # EC-transfer producers never count as having executed the model.
            if not self.ec_producer:
                model_executed = scheduler_output.total_num_scheduled_tokens > 0

            if not model_executed:
                # No sampling required (no requests scheduled).
                future = cast(Future[ModelRunnerOutput], exec_future)
            else:
                exec_future.add_done_callback(self._log_err_callback(scheduler_output))

                if not scheduler_output.pending_structured_output_tokens:
                    # We aren't waiting for any tokens, get any grammar output
                    # and sample immediately.
                    grammar_output = self.scheduler.get_grammar_bitmask(
                        scheduler_output
                    )
                    future = self.model_executor.sample_tokens(
                        grammar_output, non_block=True
                    )
                else:
                    # We need to defer sampling until we have processed the
                    # model output from the prior step.
                    deferred_scheduler_output = scheduler_output

            if deferred_scheduler_output is None:
                # Add this step's future to the queue.
                batch_queue.appendleft((future, scheduler_output))
                if (
                    model_executed
                    and len(batch_queue) < self.batch_queue_size
                    and not batch_queue[-1][0].done()
                ):
                    # Don't block on next worker response unless the queue is full
                    # or there are no more requests to schedule.
                    return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output = batch_queue.pop()
        with self.log_error_detail(scheduler_output):
            model_output = future.result()

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        # NOTE(nick): We can either handle the deferred tasks here or save
        # in a field and do it immediately once step_with_batch_queue is
        # re-called. The latter slightly favors TTFT over TPOT/throughput.
        if deferred_scheduler_output is not None:
            # We now have the tokens needed to compute the bitmask for the
            # deferred request. Get the bitmask and call sample tokens.
            grammar_output = self.scheduler.get_grammar_bitmask(
                deferred_scheduler_output
            )
            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
            batch_queue.appendleft((future, deferred_scheduler_output))

        return engine_core_outputs, model_executed

    def shutdown(self):
        """Clear the structured-output backend and shut down executor and scheduler."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling on the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Clear the multi-modal caches held by this engine and by the executor."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(self):
        """Delegate to ``scheduler.reset_prefix_cache``."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Delegate to ``model_executor.sleep`` with the given level."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: list[str] | None = None):
        """Delegate to ``model_executor.wake_up`` with the given tags."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` attribute."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Delegate to ``model_executor.execute_dummy_batch``."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to ``model_executor.add_lora``."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to ``model_executor.remove_lora``."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Delegate to ``model_executor.list_loras``."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to ``model_executor.pin_lora``."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Delegate to ``model_executor.save_sharded_state`` (keyword args)."""
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Forward to ``model_executor.collective_rpc`` and return its result list."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The constructed ``Request`` and ``request.current_wave``.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave

551
552
553
554

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

555
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
556

557
558
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        engine_index: int = 0,
    ):
        """Handshake with the front-end(s), build the engine and start IO threads.

        Everything that needs the handshake addresses runs inside the
        ``_perform_handshakes`` context, so the READY message(s) are only sent
        once the engine is constructed, both IO threads are running and the
        input thread has signalled readiness.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()

        def _signal_executor_failure():
            # Surface executor failures to the busy loop via the input queue.
            self.input_queue.put_nowait((EngineCoreRequestType.EXECUTOR_FAILED, b""))

        self.engine_index = engine_index
        zmq_identity = self.engine_index.to_bytes(2, "little")
        self.engines_running = False

        handshakes = self._perform_handshakes(
            handshake_address,
            zmq_identity,
            local_client,
            vllm_config,
            client_handshake_address,
        )
        with handshakes as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            # Read back by _perform_handshake when it builds the READY message.
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            # Request-queue stats go to the coordinator only for the "internal"
            # and "hybrid" LB modes (i.e. not with an external LB).
            external_lb = vllm_config.parallel_config.data_parallel_external_lb
            self.publish_dp_lb_stats = self.has_coordinator and not external_lb

            # Must precede EngineCore init, which may read self.dp_group.
            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config, executor_class, log_stats, _signal_executor_failure
            )

            # Background threads move data between the ZMQ sockets and the
            # queues so socket IO (which releases the GIL) and some
            # (de)serialization overlap with the model forward pass; the busy
            # loop only touches the queues.
            coordinator_ready = threading.Event()
            input_io_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    zmq_identity,
                    coordinator_ready,
                ),
                daemon=True,
            )
            input_io_thread.start()

            output_args = (
                addresses.outputs,
                addresses.coordinator_output,
                self.engine_index,
            )
            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=output_args,
                daemon=True,
            )
            self.output_thread.start()

            # Hold back the handshake until the DP coordinator's READY message
            # has been received by the input thread.
            while True:
                if coordinator_ready.wait(timeout=10):
                    break
                if not input_io_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

        # If enable, attach GC debugger after static variable freeze.
        maybe_attach_gc_debug_callback()

        # From here on, assume environment variables are no longer overridden
        # and let them be cached.
        enable_envs_cache()

Rui Qiao's avatar
Rui Qiao committed
653
    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes and yield the resulting ZMQ addresses.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.

        After the caller's ``with`` body completes without error, the READY
        message(s) are sent and ``vllm_config.__post_init__()`` is re-run,
        since the handshake may have changed the config.
        """
        ctx = zmq.Context()
        headless = not local_client
        # The primary handshake counts as "local" only when there is no
        # separate colocated front-end to talk to.
        primary = self._perform_handshake(
            ctx,
            handshake_address,
            identity,
            local_client and client_handshake_address is None,
            headless,
            vllm_config,
            vllm_config.parallel_config,
        )

        if client_handshake_address is not None:
            assert local_client
            colocated = self._perform_handshake(
                ctx, client_handshake_address, identity, True, False, vllm_config
            )
            with primary as addresses, colocated as client_addresses:
                # Client input/output sockets come from the colocated
                # front-end; all other addresses from the primary handshake.
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
                yield addresses
        else:
            with primary as addresses:
                yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run one registration handshake with a front-end over a DEALER socket.

        Registers via ``startup_handshake`` and yields the addresses it
        returns. Once the caller's ``with`` body finishes without raising, a
        msgpack-encoded READY message is sent on the same socket. If the body
        raises, the exception propagates out of the ``yield`` and no READY
        message is sent.
        """
        socket_cm = make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        )
        with socket_cm as sock:
            # Register engine with front-end and hand the addresses to the caller.
            yield self.startup_handshake(
                sock, local_client, headless, parallel_config_to_update
            )

            # The coordinator stats update address is passed back here for the
            # external LB case, for our colocated front-end to use (the
            # coordinator only runs with rank 0).
            ready_msg = dict(
                status="READY",
                local=local_client,
                headless=headless,
                num_gpu_blocks=vllm_config.cache_config.num_gpu_blocks,
                dp_stats_address=self.frontend_stats_publish_address,
            )
            # Include config hash for DP configuration validation.
            parallel_config = vllm_config.parallel_config
            if parallel_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = parallel_config.compute_hash()

            sock.send(msgspec.msgpack.encode(ready_msg))
Rui Qiao's avatar
Rui Qiao committed
758

759
    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Send HELLO to the front-end and wait for its init message.

        When ``parallel_config`` is given, every key/value pair of the
        received ``parallel_config`` mapping is applied to it with
        ``setattr``.

        Returns:
            The addresses carried by the init message.

        Raises:
            RuntimeError: if no reply arrives within
                ``HANDSHAKE_TIMEOUT_MINS`` minutes.
        """
        hello = {"status": "HELLO", "local": local_client, "headless": headless}
        handshake_socket.send(msgspec.msgpack.encode(hello))

        logger.debug("Waiting for init message from front-end.")
        timeout_ms = HANDSHAKE_TIMEOUT_MINS * 60_000
        if not handshake_socket.poll(timeout=timeout_ms):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                "minutes"
            )

        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            handshake_socket.recv(), type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for attr_name, attr_value in init_message.parallel_config.items():
                setattr(parallel_config, attr_name, attr_value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers, then builds and runs the engine:

        - ``DPEngineCoreProc`` when ``data_parallel_size > 1`` or
          ``dp_rank > 0``; the DP ranks are stored on the parallel config
          first.
        - ``EngineCoreProc`` otherwise.

        Both are constructed from ``args``/``kwargs``, and ``kwargs`` must
        contain ``vllm_config``. If a fatal error occurs after construction,
        the client is notified via ``_send_engine_dead``. The engine is
        always shut down on exit.
        """
        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: EngineCoreProc | None = None
        try:
            parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                set_process_title("EngineCore", f"DP{dp_rank}")
                decorate_logs()
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                set_process_title("EngineCore")
                decorate_logs()
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare ``raise`` re-raises the active exception with its original
            # traceback; ``raise e`` would add a redundant entry for this line.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        pass

853
854
855
    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Repeats forever: wait for input, then take one engine step. It only
        ends via an exception, e.g. the ``SystemExit`` raised by the
        SIGTERM/SIGINT handler installed in ``run_engine_core``.
        """
        while True:
            self._process_input_queue()  # 1) block until there is work to do
            self._process_engine_step()  # 2) step the core, enqueue outputs

    def _process_input_queue(self):
        """Handle client requests; return once an engine step is needed.

        While no engines are running, the scheduler has no requests and
        ``batch_queue`` is empty, this blocks on ``input_queue`` and handles
        each request as it arrives. It then handles any requests that are
        still queued, without blocking.
        """
        logged_idle = False
        while not (
            self.engines_running or self.scheduler.has_requests() or self.batch_queue
        ):
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                logged_idle = True
            self._handle_client_request(*self.input_queue.get())

        if logged_idle:
            logger.debug("EngineCore loop active.")

        # Drain whatever else is already queued without blocking.
        while not self.input_queue.empty():
            self._handle_client_request(*self.input_queue.get_nowait())

    def _process_engine_step(self) -> bool:
        """Run one engine step and enqueue its outputs.

        Calls ``step_fn``. If it returned a non-empty mapping, each
        ``(client_index, outputs)`` item of that mapping is put on
        ``output_queue``. ``post_step`` is then invoked with the
        model-executed flag.

        Returns:
            Whether the model was executed during this step.
        """
        step_outputs, model_executed = self.step_fn()
        if step_outputs:
            for item in step_outputs.items():
                self.output_queue.put_nowait(item)
        self.post_step(model_executed)
        return model_executed

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch request from client.

        Request types and the shape of ``request`` for each:

        - ADD: ``(req, request_wave)``, forwarded to ``add_request``.
        - ABORT: passed unchanged to ``abort_requests``.
        - UTILITY: ``(client_idx, call_id, method_name, args)``. The named
          method of ``self`` is called. Its result, or a failure message, is
          queued back to ``client_idx`` as a ``UtilityOutput``.
        - EXECUTOR_FAILED: raises ``RuntimeError``.

        Any other type is logged as an error and ignored.
        """
        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            interrupt: BaseException | None = None
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (
                    f"Call to {method_name} method failed: {str(e)}"
                )
                if not isinstance(e, Exception):
                    # SystemExit / KeyboardInterrupt (e.g. from the one-shot
                    # SIGTERM/SIGINT handler in run_engine_core) must still
                    # propagate once the client has been answered. Otherwise
                    # the shutdown request would be silently swallowed.
                    interrupt = e
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output))
            )
            if interrupt is not None:
                raise interrupt
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
        arg type, try converting to msgspec object.

        Args are paired positionally with ``method``'s parameters. An arg is
        converted with ``msgspec.convert`` only when its parameter is
        annotated with a ``msgspec.Struct`` subclass and the arg is not
        already an instance of it. All other args pass through unchanged.

        Returns:
            ``args`` itself when it is empty, otherwise a new tuple.

        Raises:
            TypeError: if more args are given than ``method`` has parameters.
        """
        if not args:
            return args
        params = signature(method).parameters.values()
        # Explicit check rather than ``assert``: under ``python -O`` the
        # assert is stripped and ``zip`` below would silently drop the
        # surplus args.
        if len(args) > len(params):
            raise TypeError(
                f"{getattr(method, '__name__', method)!r} accepts at most "
                f"{len(params)} positional arguments but {len(args)} were given"
            )
        return tuple(
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, params)
        )

    def _send_engine_dead(self):
        """Notify the EngineCoreClient that this engine core is dead.

        Puts the ``ENGINE_CORE_DEAD`` sentinel on ``output_queue``. It then
        waits up to 5 seconds for ``output_thread`` to finish, and logs a
        fatal error if the thread is still alive afterwards.
        """
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Give the sending thread a chance to flush the message before shutdown.
        sender = self.output_thread
        sender.join(timeout=5.0)
        if not sender.is_alive():
            return
        logger.fatal(
            "vLLM shutdown signal from EngineCore failed "
            "to send. Please report this issue."
        )

    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Setup:

        - Connects a DEALER socket to each of ``input_addresses``.
        - When ``coord_input_address`` is given, also connects an XSUB
          socket to the DP coordinator, subscribes, and waits for its READY
          message.
        - Sets ``ready_event`` once all sockets are registered.

        It then loops forever. Each incoming ``(request_type, payload)``
        message is decoded (ADD requests also go through
        ``preprocess_add_request``) and pushed onto ``input_queue`` for the
        busy loop.

        Raises:
            RuntimeError: if the coordinator's first message is not READY.
        """
        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. This must not be an
                # ``assert``: under ``python -O`` the recv() would be skipped
                # and the READY frame later misread as a request type.
                ready_msg = coord_socket.recv()
                if ready_msg != b"READY":
                    raise RuntimeError(
                        "Expected READY message from DP coordinator, "
                        f"got {ready_msg!r}"
                    )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    if request_type == EngineCoreRequestType.ADD:
                        request = add_request_decoder.decode(data_frames)
                        request = self.preprocess_add_request(request)
                    else:
                        request = generic_decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread.

        Takes ``(client_index, EngineCoreOutputs)`` items from
        ``output_queue`` and stamps each with ``engine_index``. Each item is
        sent on the PUSH socket for ``client_index``; index -1 goes to the
        DP coordinator socket instead. On the ``ENGINE_CORE_DEAD`` sentinel,
        the sentinel is forwarded to every client socket and the thread
        exits.
        """
        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                )
                for output_path in output_paths
            ]
            coord_socket = (
                stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
                if coord_output_path is not None
                else None
            )
            # Upper bound on the number of idle buffers kept for reuse.
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with. New entries are
                # added on the left, so the oldest send is checked first.
                # Apply the same cap as below so that a burst of in-flight
                # sends cannot permanently grow the reuse pool.
                while pending and pending[-1][0].done:
                    _, _, done_buffer = pending.pop()
                    if len(reuse_buffers) < max_reuse_bufs:
                        reuse_buffers.append(done_buffer)

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(
                    buffers, copy=False, track=True
                )
                if not tracker.done:
                    # ``outputs`` is retained only for multi-frame sends
                    # (presumably the extra frames are its zero-copy
                    # tensor/array buffers).
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        self.current_wave = 0
        # Last value of scheduler.get_request_counts() that was published;
        # used to publish only when the counts change.
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Validate the DP ranks and create the stateless DP process group.

        If a KV-transfer config is present, its ``engine_id`` is suffixed
        with ``_dp{local_dp_rank}`` so it is unique per DP rank.
        """
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert local_dp_rank is not None
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug(
                "Setting kv_transfer_config.engine_id to %s",
                vllm_config.kv_transfer_config.engine_id,
            )

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine, then destroy the DP process group if set."""
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request, keeping track of the DP wave it belongs to.

        This only applies when a coordinator is present:

        - A request for a newer wave advances ``current_wave``.
        - A request for an older wave that arrives while engines are idle
          emits a ``start_wave`` output on client index -1 (the coordinator
          socket).
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle START_DP_WAVE here; defer all other types to the base class.

        A START_DP_WAVE ``request`` is ``(new_wave, exclude_eng_index)``. It
        is ignored when this engine is the excluded one or ``new_wave`` is
        older than ``current_wave``. Otherwise the wave is adopted and
        ``engines_running`` is set.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Send scheduler stats to the coordinator if the counts changed.

        Does nothing unless ``publish_dp_lb_stats`` is set.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank may still have unfinished requests.

        Increments ``step_counter``. Only every 32nd call performs the
        all-reduce over ``dp_group``; all other calls return True.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Re-create the DP process group for a new DP configuration.

        Steps:

        1. Update the DP size, rank (unless KEEP_CURRENT_RANK) and master
           ip/port in the parallel config.
        2. Re-create ``dp_group``, unless this rank is being shut down.
        3. Copy ``data_parallel_master_port`` back into ``reconfig_request``
           and forward the request to the model executor.
        4. When the DP size grows, share the available KV-cache memory size
           with the new ranks and warm up the model.
        5. A rank marked SHUTDOWN_CURRENT_RANK is shut down at the end.
        """
        # NOTE(review): self.shutdown() also destroys self.dp_group when it is
        # set, so the old group appears to be destroyed twice here; confirm
        # that stateless_destroy_torch_distributed_process_group tolerates it.
        stateless_destroy_torch_distributed_process_group(self.dp_group)
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )


class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        """Store ``addresses``, pin the DP ranks and initialize the engine.

        The DP ranks are written into ``vllm_config.parallel_config``, and
        device visibility is set before ``DPEngineCoreProc.__init__`` runs.
        That constructor receives an empty handshake address, because the
        actor's addresses are already known (see ``_perform_handshakes``).
        """
        parallel_config = vllm_config.parallel_config
        self.addresses = addresses
        parallel_config.data_parallel_rank = dp_rank
        parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

        super().__init__(vllm_config, local_client, "", executor_class, log_stats)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict device visibility for this DP rank; skipped on XPU.

        On other platforms, the platform's device-control env var is set via
        ``_set_cuda_visible_devices``.
        """
        from vllm.platforms import current_platform

        if current_platform.is_xpu():
            return
        self._set_cuda_visible_devices(
            vllm_config, local_dp_rank, current_platform.device_control_env_var
        )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Set ``device_control_env_var`` to this DP rank's device indices.

        The value comes from ``get_device_indices``, called with the env var
        name, ``local_dp_rank`` and the parallel ``world_size``.

        Raises:
            RuntimeError: if ``get_device_indices`` raises ``IndexError``. The
                message reports the local index range and the env var's
                current value. ``RuntimeError`` is an ``Exception`` subclass,
                so existing ``except Exception`` handlers still apply.
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e
        os.environ[device_control_env_var] = value

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """Yield the pre-assigned ZMQ addresses; no handshake is performed.

        For Ray, all address information is known before the actor is
        created and is stored in ``self.addresses`` by ``__init__``. The
        remaining arguments are accepted for signature compatibility only.
        """
        yield self.addresses

    def wait_for_init(self):
        """Readiness barrier for the actor; does nothing by itself.

        When ``ray.get()`` on this method (or any other actor method)
        returns, actor creation (i.e. ``__init__``) is guaranteed to have
        completed. Calling it therefore waits for engine-core
        initialization.
        """

    def run(self):
        """Run the engine core busy loop and always shut down afterwards.

        ``SystemExit`` is logged at debug level. Any other ``Exception`` is
        logged as a fatal error. Every exception is re-raised.
        """
        try:
            self.run_busy_loop()
        except BaseException as exc:
            if isinstance(exc, SystemExit):
                logger.debug("EngineCore exiting.")
            elif isinstance(exc, Exception):
                logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()