core.py 55.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
import queue
5
import signal
6
7
import threading
import time
8
from collections import deque
9
from collections.abc import Callable, Generator
10
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
11
from contextlib import ExitStack, contextmanager
12
from inspect import isclass, signature
13
from logging import DEBUG
14
from typing import Any, TypeVar, cast
15

16
import msgspec
17
18
import zmq

19
20
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
21
from vllm.envs import enable_envs_cache
22
from vllm.logger import init_logger
23
from vllm.logging_utils.dump_input import dump_engine_exception
24
from vllm.lora.request import LoRARequest
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.multimodal.cache import engine_receiver_cache_from_config
27
from vllm.tasks import POOLING_TASKS, SupportedTask
28
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
29
30
31
32
from vllm.utils.gc_utils import (
    freeze_gc_heap,
    maybe_attach_gc_debug_callback,
)
33
from vllm.utils.hashing import get_hash_fn_by_name
34
from vllm.utils.network_utils import make_zmq_socket
35
from vllm.utils.system_utils import decorate_logs, set_process_title
36
37
38
39
40
41
42
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
43
from vllm.v1.core.sched.interface import SchedulerInterface
44
from vllm.v1.core.sched.output import SchedulerOutput
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from vllm.v1.engine import (
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
59
from vllm.v1.executor import Executor
60
from vllm.v1.kv_cache_interface import KVCacheConfig
61
from vllm.v1.metrics.stats import SchedulerStats
62
from vllm.v1.outputs import ModelRunnerOutput
63
from vllm.v1.request import Request, RequestStatus
64
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
65
from vllm.v1.structured_output import StructuredOutputManager
66
67
68
69
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Module-level timeouts; their consumers are outside this excerpt.
# The `_S` suffix denotes seconds.
POLLING_TIMEOUT_S = 2.5
# The `_MINS` suffix denotes minutes.
HANDSHAKE_TIMEOUT_MINS = 5

# Generic result type of a single worker's answer in `collective_rpc`.
_R = TypeVar("_R")  # Return type for collective_rpc
74

75
76
77
78

class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
    ):
        """Build the model executor, KV caches, scheduler and helper state.

        Args:
            vllm_config: Engine configuration. Its cache config (block counts)
                and scheduler config (chunked-prefill flag) are updated in
                place during initialization.
            executor_class: Executor type, instantiated with `vllm_config`.
            log_stats: Whether the scheduler should collect statistics.
            executor_fail_callback: Optional callback registered with the
                executor through `register_failure_callback`.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        if vllm_config.parallel_config.data_parallel_rank == 0:
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        # Placeholder; `_initialize_kv_caches` overwrites it when the model
        # has a KV cache.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        Scheduler = vllm_config.scheduler_config.get_scheduler_cls()

        if len(kv_cache_config.kv_cache_groups) == 0:
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.enable_chunked_prefill = False

        # The scheduler's block size is the cache block size scaled by the
        # decode and prefill context-parallel sizes.
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
            * vllm_config.parallel_config.prefill_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None

        # Fetch the KV connector through the SchedulerInterface accessor
        # rather than the implementation-specific `connector` attribute
        # (which is not declared on the interface), and reuse it below.
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            self.model_executor.init_kv_output_aggregator(kv_connector)

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = engine_receiver_cache_from_config(
            vllm_config, mm_registry
        )

        # If a KV connector is initialized for scheduler, we want to collect
        # handshake metadata from all workers so the connector in the scheduler
        # will have the full context
        if kv_connector is not None:
            # Collect and store KV connector xfer metadata from workers
            # (after KV cache registration)
            xfer_handshake_metadata = (
                self.model_executor.get_kv_connector_handshake_metadata()
            )

            if xfer_handshake_metadata:
                # xfer_handshake_metadata is list of dicts from workers
                # Each dict already has structure {tp_rank: metadata}
                # Merge all worker dicts into a single dict
                content: dict[int, Any] = {}
                for worker_dict in xfer_handshake_metadata:
                    if worker_dict is not None:
                        content.update(worker_dict)
                kv_connector.set_xfer_handshake_metadata(content)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        self.is_ec_producer = (
            vllm_config.ec_transfer_config is not None
            and vllm_config.ec_transfer_config.is_ec_producer
        )
        self.is_pooling_model = vllm_config.model_config.runner_type == "pooling"

        # Block hashes are needed for prefix caching and by KV connectors.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )
        self.async_scheduling = vllm_config.scheduler_config.async_scheduling

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()

    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Determine KV cache sizes, initialize them on the executor, warm up.

        Side effect: sets `self.available_gpu_memory_for_kv_cache` when the
        model has any KV cache spec.

        Returns:
            `(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)`;
            `num_cpu_blocks` is always 0 here.
        """
        # perf_counter is monotonic, so the elapsed time is immune to
        # wall-clock adjustments.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Elastic EP scale-up launch: no local profiling; the memory
                # size comes from ParallelConfig.sync_kv_cache_memory_size over
                # the DP group (which must already be set on `self`).
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )
        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info_once(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
            scope="local",
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the model executor."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case. Not used by this base implementation.

        Raises:
            TypeError: if `request.request_id` is not a `str`.
            ValueError: if a pooling request asks for a task the model does
                not support.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        # Only a warning is logged here; the request is still forwarded
        # unchanged to the scheduler.
        if request.kv_transfer_params is not None and (
            not self.scheduler.get_kv_connector()
        ):
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Dump engine state if the wrapped block raises, then re-raise.

        Only `Exception` is caught: we only want to dump info when the error
        comes from model execution itself, not for `BaseException`s such as
        `KeyboardInterrupt`.
        """
        try:
            yield
        except Exception:
            # NOTE: This method is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare `raise` propagates the active exception unchanged.
            raise

    def _log_err_callback(self, scheduler_output: SchedulerOutput):
        """Log error details of a future that's not expected to return a result."""

        def callback(f, sched_output=scheduler_output):
            with self.log_error_detail(sched_output):
                result = f.result()
                assert result is None

        return callback

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output, non_block=True)
        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
        with self.log_error_detail(scheduler_output):
            model_output = future.result()
            # A None result means sampling is a separate executor call.
            if model_output is None:
                model_output = self.model_executor.sample_tokens(grammar_output)

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

    def post_step(self, model_executed: bool) -> None:
        """Per-step follow-up work after `step_fn` returns.

        Feeds draft token ids back to the scheduler when speculative decoding
        is enabled without async scheduling.
        """
        # When using async scheduling we can't get draft token ids in advance,
        # so we update draft token ids in the worker process and don't
        # need to update draft token ids here.
        if not self.async_scheduling and self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the
        oldest batch in the batch queue is finished.
        3. Update the scheduler from the output.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        deferred_scheduler_output: SchedulerOutput | None = None
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            exec_future = self.model_executor.execute_model(
                scheduler_output, non_block=True
            )
            if not self.is_ec_producer:
                model_executed = scheduler_output.total_num_scheduled_tokens > 0

            if self.is_pooling_model or not model_executed:
                # No sampling required (no requests scheduled).
                future = cast(Future[ModelRunnerOutput], exec_future)
            else:
                exec_future.add_done_callback(self._log_err_callback(scheduler_output))

                if not scheduler_output.pending_structured_output_tokens:
                    # We aren't waiting for any tokens, get any grammar output
                    # and sample immediately.
                    grammar_output = self.scheduler.get_grammar_bitmask(
                        scheduler_output
                    )
                    future = self.model_executor.sample_tokens(
                        grammar_output, non_block=True
                    )
                else:
                    # We need to defer sampling until we have processed the model output
                    # from the prior step.
                    deferred_scheduler_output = scheduler_output

            if deferred_scheduler_output is None:
                # Add this step's future to the queue. New batches go on the
                # left; the oldest in-flight batch is at the right.
                batch_queue.appendleft((future, scheduler_output))
                if (
                    model_executed
                    and len(batch_queue) < self.batch_queue_size
                    and not batch_queue[-1][0].done()
                ):
                    # Don't block on next worker response unless the queue is full
                    # or there are no more requests to schedule.
                    return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output = batch_queue.pop()
        with self.log_error_detail(scheduler_output):
            model_output = future.result()

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        # NOTE(nick): We can either handle the deferred tasks here or save
        # in a field and do it immediately once step_with_batch_queue is
        # re-called. The latter slightly favors TTFT over TPOT/throughput.
        if deferred_scheduler_output is not None:
            # We now have the tokens needed to compute the bitmask for the
            # deferred request. Get the bitmask and call sample tokens.
            grammar_output = self.scheduler.get_grammar_bitmask(
                deferred_scheduler_output
            )
            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
            batch_queue.appendleft((future, deferred_scheduler_output))

        return engine_core_outputs, model_executed

    def shutdown(self):
        """Release the structured-output backend, executor and scheduler.

        Each stage runs even if an earlier one raises; the first error is
        still propagated. Attributes never assigned (partially constructed
        instance) are skipped.
        """
        structured_output_manager = getattr(self, "structured_output_manager", None)
        model_executor = getattr(self, "model_executor", None)
        scheduler = getattr(self, "scheduler", None)
        try:
            if structured_output_manager is not None:
                structured_output_manager.clear_backend()
        finally:
            try:
                if model_executor:
                    model_executor.shutdown()
            finally:
                if scheduler:
                    scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Start (`is_start=True`) or stop profiling on the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Clear the multi-modal caches in this process and on the executor."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(self):
        """Delegate a prefix-cache reset to the scheduler."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the executor to sleep at the given `level`."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: list[str] | None = None):
        """Wake the executor, optionally restricted to `tags`."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's `is_sleeping` flag."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run a dummy batch on the executor."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter on the executor; returns the executor's result."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter by id; returns the executor's result."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the ids of the LoRA adapters known to the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter by id; returns the executor's result."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Ask the executor to save the model's sharded state under `path`."""
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Invoke `method` through the executor's `collective_rpc`."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The internal `Request` and the client's `current_wave`.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave

553
554
555
556

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

557
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
558

559
560
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        engine_index: int = 0,
    ):
        """Start an engine core that talks to its front-end(s) over ZMQ.

        Runs the startup handshake(s), initializes the inner `EngineCore`
        while the handshake context is open, and starts daemon threads that
        move data between the ZMQ sockets and `input_queue` / `output_queue`.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()

        def executor_fail_callback():
            # Report executor failure to the busy loop through the input queue.
            self.input_queue.put_nowait((EngineCoreRequestType.EXECUTOR_FAILED, b""))

        self.engine_index = engine_index
        # 2-byte little-endian engine index, used as the ZMQ socket identity.
        engine_identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(
            handshake_address,
            engine_identity,
            local_client,
            vllm_config,
            client_handshake_address,
        ) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            # Request queue stats are published to the coordinator only in the
            # "internal" and "hybrid" LB modes, i.e. when a coordinator exists
            # and DP load balancing is not external.
            self.publish_dp_lb_stats = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config, executor_class, log_stats, executor_fail_callback
            )

            # Background IO threads shuttle Socket <-> Queue so that ZMQ IO
            # (which releases the GIL) and some (de)serialization overlap with
            # the model forward pass; core_busy_loop only touches the queues.
            coordinator_ready = threading.Event()

            input_socket_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    engine_identity,
                    coordinator_ready,
                ),
                daemon=True,
            )
            input_socket_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(
                    addresses.outputs,
                    addresses.coordinator_output,
                    self.engine_index,
                ),
                daemon=True,
            )
            self.output_thread.start()

            # Hold off completing the handshake until the input thread has
            # received the DP coordinator's ready message.
            while True:
                if coordinator_ready.wait(timeout=10):
                    break
                if not input_socket_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

        # If enabled, attach the GC debugger now that the startup heap has
        # been frozen (done at the end of EngineCore.__init__).
        maybe_attach_gc_debug_callback()

        # Enable environment variable cache (e.g. assume no more
        # environment variable overrides after this point)
        enable_envs_cache()

Rui Qiao's avatar
Rui Qiao committed
655
    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.

        The ZMQ context backing the handshake sockets is terminated when this
        context manager exits (normally or via an exception).
        """
        input_ctx = zmq.Context()
        try:
            is_local = local_client and client_handshake_address is None
            headless = not local_client
            handshake = self._perform_handshake(
                input_ctx,
                handshake_address,
                identity,
                is_local,
                headless,
                vllm_config,
                vllm_config.parallel_config,
            )
            if client_handshake_address is None:
                with handshake as addresses:
                    yield addresses
            else:
                assert local_client
                local_handshake = self._perform_handshake(
                    input_ctx, client_handshake_address, identity, True, False, vllm_config
                )
                with handshake as addresses, local_handshake as client_addresses:
                    # Client sockets come from the colocated front-end; the
                    # rest of the addresses from the primary handshake.
                    addresses.inputs = client_addresses.inputs
                    addresses.outputs = client_addresses.outputs
                    yield addresses

            # Update config which may have changed from the handshake
            vllm_config.__post_init__()
        finally:
            # Terminate the context explicitly instead of leaving it to the
            # garbage collector. Every socket opened on it was closed when its
            # `with` block in _perform_handshake exited, so term() only waits
            # for still-queued messages (e.g. READY), bounded by their linger.
            input_ctx.term()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run one handshake with a front-end over a DEALER socket.

        Registers this engine via `startup_handshake` and yields the returned
        addresses. If the caller's block completes without raising, a READY
        message is then sent back over the same socket.
        """
        with make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        ) as sock:
            # Register engine with front-end.
            yield self.startup_handshake(
                sock, local_client, headless, parallel_config_to_update
            )

            # READY payload (key order is preserved in the encoded message).
            # `dp_stats_address` hands the coordinator stats update address to
            # our colocated front-end for the external LB case (the
            # coordinator only runs with rank 0).
            ready_msg = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
                "dp_stats_address": self.frontend_stats_publish_address,
            }
            # Include config hash for DP configuration validation.
            dp_config = vllm_config.parallel_config
            if dp_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = dp_config.compute_hash()

            sock.send(msgspec.msgpack.encode(ready_msg))
Rui Qiao's avatar
Rui Qiao committed
760

761
    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Register with the front-end and wait for its init message.

        Sends a HELLO message, then waits at most ``HANDSHAKE_TIMEOUT_MINS``
        minutes for an ``EngineHandshakeMetadata`` reply. When
        ``parallel_config`` is given, every entry of the reply's
        ``parallel_config`` mapping is applied to it with ``setattr``.

        Returns:
            The ``addresses`` carried by the init message.

        Raises:
            RuntimeError: if no reply arrives within the timeout.
        """
        hello = {"status": "HELLO", "local": local_client, "headless": headless}
        handshake_socket.send(msgspec.msgpack.encode(hello))

        logger.debug("Waiting for init message from front-end.")
        timeout_ms = HANDSHAKE_TIMEOUT_MINS * 60_000
        if not handshake_socket.poll(timeout=timeout_ms):
            raise RuntimeError(
                "Did not receive response from front-end process within "
                f"{HANDSHAKE_TIMEOUT_MINS} minutes"
            )

        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            handshake_socket.recv(), type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for attr_name, attr_value in init_message.parallel_config.items():
                setattr(parallel_config, attr_name, attr_value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers, constructs a ``DPEngineCoreProc``
        when ``data_parallel_size > 1`` or ``dp_rank > 0`` (otherwise an
        ``EngineCoreProc``), runs its busy loop, and always shuts the engine
        down on the way out.

        Args:
            *args, **kwargs: Forwarded to the engine constructor; ``kwargs``
                must contain ``"vllm_config"``.
            dp_rank: Data parallel rank for this engine process.
            local_dp_rank: Local data parallel rank for this engine process.

        Raises:
            SystemExit: re-raised after logging (e.g. from the signal handler).
            Exception: any other failure is logged and re-raised; if the
                engine had been constructed, an ENGINE_CORE_DEAD message is
                sent to the client first.
        """
        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: EngineCoreProc | None = None
        try:
            parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                set_process_title("EngineCore", f"DP{dp_rank}")
                decorate_logs()
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                set_process_title("EngineCore")
                decorate_logs()
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare raise keeps the original traceback untouched.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Data-parallel setup hook; intentionally does nothing here.

        ``DPEngineCoreProc`` overrides this to configure its DP process group.
        """

    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Each iteration waits for / drains client requests and then performs
        one engine step. There is no normal exit: the loop only ends when an
        exception (e.g. the ``SystemExit`` raised on SIGINT/SIGTERM)
        propagates out of it.
        """
        while True:
            self._process_input_queue()  # wait for work, handle requests
            self._process_engine_step()  # step the core, enqueue outputs

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed."""

        waited = False
869
870
871
872
873
        while (
            not self.engines_running
            and not self.scheduler.has_requests()
            and not self.batch_queue
        ):
874
875
876
877
878
879
880
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
881
            logger.debug("EngineCore loop active.")
882
883
884
885
886
887

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

888
    def _process_engine_step(self) -> bool:
889
890
891
        """Called only when there are unfinished local requests."""

        # Step the engine core.
892
        outputs, model_executed = self.step_fn()
893
        # Put EngineCoreOutputs into the output queue.
894
        for output in outputs.items() if outputs else ():
895
            self.output_queue.put_nowait(output)
896
897
        # Post-step hook.
        self.post_step(model_executed)
898

899
900
        return model_executed

901
902
903
    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch request from client.

        ADD and ABORT are forwarded to ``add_request`` / ``abort_requests``.
        UTILITY invokes the named method on this engine and always queues a
        ``UtilityOutput`` (result or failure message) for the calling client.

        Raises:
            RuntimeError: for an EXECUTOR_FAILED request.
            BaseException: a non-``Exception`` error (e.g. ``SystemExit`` or
                ``KeyboardInterrupt``) raised during a utility call is
                re-raised after the failure has been reported to the client.
        """
        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (
                    f"Call to {method_name} method failed: {str(e)}"
                )
                if not isinstance(e, Exception):
                    # SystemExit / KeyboardInterrupt (e.g. from the one-shot
                    # SIGTERM/SIGINT handler installed in run_engine_core)
                    # must propagate; swallowing it would lose the shutdown
                    # request because the handler never raises a second time.
                    raise
            finally:
                # Always answer the client, including when re-raising above.
                self.output_queue.put_nowait(
                    (client_idx, EngineCoreOutputs(utility_output=output))
                )
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
936
        arg type, try converting to msgspec object."""
937
938
939
940
941
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
942
943
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
944
            and issubclass(p.annotation, msgspec.Struct)
945
946
947
948
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, arg_types)
        )
949

950
951
952
953
954
955
956
957
958
    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient.

        Queues the ``ENGINE_CORE_DEAD`` sentinel, then joins the output thread
        for up to 5 seconds so the message can go out before shutdown. Logs a
        fatal error if the thread is still alive after that.
        """
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        sender = self.output_thread
        sender.join(timeout=5.0)
        if not sender.is_alive():
            return
        logger.fatal(
            "vLLM shutdown signal from EngineCore failed "
            "to send. Please report this issue."
        )

    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Connects a DEALER socket to each address in ``input_addresses`` and,
        when ``coord_input_address`` is given, an XSUB socket to the
        coordinator. Sets ``ready_event`` once all sockets are ready, then
        loops forever: each multipart message is decoded into
        ``(request_type, request)`` and pushed onto ``self.input_queue``.

        Raises:
            RuntimeError: if the coordinator's first message is not READY.
        """
        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. The recv() must not
                # sit inside an assert: under `python -O` it would be stripped,
                # leaving the READY frame unread and signalling readiness early.
                coord_ready = coord_socket.recv()
                if coord_ready != b"READY":
                    raise RuntimeError(
                        f"Expected READY from coordinator, got {coord_ready!r}"
                    )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event

            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    if request_type == EngineCoreRequestType.ADD:
                        request = add_request_decoder.decode(data_frames)
                        request = self.preprocess_add_request(request)
                    else:
                        request = generic_decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread.

        Takes ``(client_index, outputs)`` items from ``self.output_queue``,
        stamps ``engine_index`` on them and sends them on the PUSH socket at
        ``sockets[client_index]`` (one socket per entry of ``output_paths``).
        A ``client_index`` of -1 routes the message to the coordinator socket
        instead. When the ``ENGINE_CORE_DEAD`` sentinel is dequeued it is sent
        on every client socket and the thread exits.

        Args:
            output_paths: Addresses of the per-client output sockets.
            coord_output_path: Coordinator output address, or None when there
                is no coordinator socket.
            engine_index: Written to ``outputs.engine_index`` before sending.
        """
        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        # New entries are added on the left; the oldest sit on the right.
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                )
                for output_path in output_paths
            ]
            coord_socket = (
                stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
                if coord_output_path is not None
                else None
            )
            # Upper bound on idle buffers kept for reuse: one per socket + 1.
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    # Notify every client socket, then stop this thread.
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with.
                # Walks from the oldest entry and stops at the first send
                # that is still in flight.
                while pending and pending[-1][0].done:
                    reuse_buffers.append(pending.pop()[2])

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                # copy=False + track=True: zmq sends without copying and the
                # returned tracker reports when it no longer needs the memory.
                tracker = sockets[client_index].send_multipart(
                    buffers, copy=False, track=True
                )
                if not tracker.done:
                    # Still in flight: keep the buffer alive, plus `outputs`
                    # when extra frames were produced (len(buffers) > 1 —
                    # presumably backing buffers taken from tensors/arrays).
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context.

    Adds DP "wave" tracking on top of ``EngineCoreProc``: ranks keep stepping
    (with dummy batches when idle) while any rank has unfinished requests,
    and a wave-complete notification is emitted once all ranks are done.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        self.current_wave = 0
        # Last (request counts) tuple published by
        # _maybe_publish_request_counts.
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Validate DP ranks and create the stateless DP process group.

        Also suffixes ``kv_transfer_config.engine_id`` with ``_dp{local rank}``
        (when a KV transfer config is present) so each rank's id is unique.
        """
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert local_dp_rank is not None
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug(
                "Setting kv_transfer_config.engine_id to %s",
                vllm_config.kv_transfer_config.engine_id,
            )

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the base engine, then destroy the DP group (once)."""
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            # Clear the attribute before destroying so that repeated
            # shutdown() calls never destroy the same group twice.
            self.dp_group = None
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Track the request's wave, then add it via the base class.

        With a coordinator: a newer wave advances ``current_wave``; a request
        for an older wave while the engines are paused asks the front-end
        (client index -1) to start ``current_wave``.
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle START_DP_WAVE here; delegate everything else to the base."""
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Queue scheduler stats for the coordinator when the counts changed.

        Does nothing unless ``publish_dp_lb_stats`` is set.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank still has unfinished requests.

        Only every 32nd call performs the cross-rank check; the calls in
        between return True without communicating.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Tear down and rebuild the DP environment for a new DP size/rank.

        When the request asks this rank to shut down, the engine is shut down
        instead of joining a new DP group.
        """
        # Detach the old group before destroying it so the shutdown() calls
        # below (which also tear down ``self.dp_group``) cannot destroy the
        # same group a second time.
        old_dp_group, self.dp_group = self.dp_group, None
        stateless_destroy_torch_distributed_process_group(old_dp_group)
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        # Copy the (possibly updated) master port back into the request that
        # is forwarded to the executor below.
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )


class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        """Store the pre-assigned addresses and DP ranks, pin visible devices,
        then run the regular DP engine initialization.

        An empty ``handshake_address`` is passed to the parent because
        ``_perform_handshakes`` is overridden to yield ``addresses`` directly.
        """
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

        super().__init__(vllm_config, local_client, "", executor_class, log_stats)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict device visibility for this DP rank (skipped on XPU)."""
        from vllm.platforms import current_platform

        if not current_platform.is_xpu():
            self._set_cuda_visible_devices(
                vllm_config, local_dp_rank, current_platform.device_control_env_var
            )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Set ``device_control_env_var`` to this DP rank's device indices.

        Per the error message below, the rank is expected to own the range
        [local_dp_rank * world_size, (local_dp_rank + 1) * world_size) of the
        current value.

        Raises:
            RuntimeError: if ``get_device_indices`` raises ``IndexError``.
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e
        os.environ[device_control_env_var] = value

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        Errors are logged and re-raised; the engine is shut down on exit.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()