utils.py 48.4 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
5
import os
6
import threading
7
import weakref
8
from collections.abc import Callable, Iterator
9
10
11
12
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
13
from multiprocessing.queues import Queue
14
from typing import TYPE_CHECKING, cast
15
from unittest.mock import patch
16
17
18
19

import msgspec
import zmq

20
from vllm import envs
21
22
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
23
from vllm.platforms import current_platform
24
from vllm.ray.ray_env import get_env_vars_to_copy
25
from vllm.utils import numa_utils
26
from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx
27
from vllm.utils.system_utils import get_mp_context
28
from vllm.v1.engine.coordinator import DPCoordinator
29
from vllm.v1.executor import Executor
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)

# Engine-startup polling interval in milliseconds (per the ``_MS`` suffix).
# NOTE(review): not referenced in this chunk; presumably consumed by the
# startup/handshake wait loop elsewhere in this module — confirm.
STARTUP_POLL_PERIOD_MS = 10000


class CoreEngineState(Enum):
    """Handshake state tracked for each ``CoreEngine``.

    Every ``CoreEngine`` starts out as ``NEW``; the transitions to
    ``CONNECTED`` and ``READY`` happen in handshake code outside this class.
    """

    # Freshly created, no handshake traffic observed yet.
    NEW = auto()
    # NOTE(review): presumably set once the engine has connected — confirm.
    CONNECTED = auto()
    # NOTE(review): presumably set once the engine reports ready — confirm.
    READY = auto()


class CoreEngine:
    """Per-data-parallel-rank record used while handshaking with engines.

    Args:
        index: Engine index. It is stored as a 2-byte little-endian
            ``identity`` (``int.to_bytes`` raises ``OverflowError`` for
            values outside 0..65535).
        local: Whether this engine is local; stored unchanged.
    """

    def __init__(self, index: int = 0, local: bool = True):
        self.local = local
        self.identity = index.to_bytes(length=2, byteorder="little")
        # Every engine begins the handshake in the NEW state.
        self.state = CoreEngineState.NEW


@dataclass
class EngineZmqAddresses:
    # ZMQ input socket addresses for each front-end client (requests)
    inputs: list[str]
    # ZMQ output socket addresses for each front-end client (responses)
    outputs: list[str]
    # ZMQ input socket address of DP coordinator if applicable
63
    coordinator_input: str | None = None
64
    # ZMQ output socket address of DP coordinator if applicable
65
    coordinator_output: str | None = None
66
67
68
    # ZMQ socket for front-end to connect to DP coordinator.
    # Not used by engine, just relayed to front-end in handshake response.
    # Only required for external DP LB case.
69
    frontend_stats_publish_address: str | None = None
70
71
72
73
74
75
76
77


@dataclass
class EngineHandshakeMetadata:
    """Metadata sent to each engine process during startup handshake,
    including addresses of the front-end ZMQ queues that they should
    connect to.
    """

    # Front-end (and optional DP coordinator) socket addresses for the engine.
    addresses: EngineZmqAddresses
    # Parallel-config values forwarded to the engine; each value is an int,
    # a str, or a list of ints.
    # NOTE(review): presumably keyed by ParallelConfig field name — confirm.
    parallel_config: dict[str, int | str | list[int]]


83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def _make_control_bundle(node_ip: str) -> dict[str, float]:
    # The engine actor is scheduled on the final CPU-only bundle. Keep that
    # bundle colocated with the group's first GPU bundle so the actor does not
    # float to an unrelated node and reorder worker ranks away from the
    # advertised DP bootstrap host.
    return {"CPU": 1.0, "node:" + node_ip: 0.001}


def _get_bundle_node_ip(bundle: dict[str, float]) -> str:
    for key in bundle:
        if key.startswith("node:"):
            return key.split(":", 1)[1]
    raise ValueError(f"Missing node affinity in placement bundle: {bundle}")


class CoreEngineProcManager:
    """
    Utility class to handle creation, readiness, and shutdown
    of background processes used by the AsyncLLM and LLMEngine.

    One ``EngineCoreProc.run_engine_core`` process is started per local
    engine. The process list is also registered with ``weakref.finalize`` so
    the processes are shut down if the manager is garbage-collected without
    an explicit :meth:`shutdown`.
    """

    def __init__(
        self,
        local_engine_count: int,
        start_index: int,
        local_start_index: int,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        tensor_queue: Queue | None = None,
    ):
        """Create and start ``local_engine_count`` engine-core processes.

        Args:
            local_engine_count: Number of engine processes to start.
            start_index: Global DP rank (``dp_rank``) of the first process.
            local_start_index: Node-local DP rank (``local_dp_rank``) of the
                first process.
            vllm_config: Engine configuration forwarded to every process.
            local_client: Forwarded to every process.
            handshake_address: Forwarded to every process.
            executor_class: Forwarded to every process.
            log_stats: Forwarded to every process.
            client_handshake_address: Forwarded only when truthy.
            tensor_queue: Forwarded to every process.
        """
        context = get_mp_context()
        common_kwargs = {
            "vllm_config": vllm_config,
            "local_client": local_client,
            "handshake_address": handshake_address,
            "executor_class": executor_class,
            "log_stats": log_stats,
            "tensor_queue": tensor_queue,
        }

        if client_handshake_address:
            common_kwargs["client_handshake_address"] = client_handshake_address

        is_dp = vllm_config.parallel_config.data_parallel_size > 1

        from vllm.v1.engine.core import EngineCoreProc

        self.processes: list[BaseProcess] = []
        local_dp_ranks = []
        for index in range(local_engine_count):
            local_index = local_start_index + index
            global_index = start_index + index

            # Start EngineCore in background process.
            local_dp_ranks.append(local_index)
            self.processes.append(
                context.Process(
                    target=EngineCoreProc.run_engine_core,
                    name=f"EngineCore_DP{global_index}" if is_dp else "EngineCore",
                    kwargs=common_kwargs
                    | {"dp_rank": global_index, "local_dp_rank": local_index},
                )
            )

        # Safety net: shut the processes down if this manager is collected
        # without shutdown() being called.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

        self.manager_stopped = threading.Event()
        self.failed_proc_name: str | None = None

        try:
            for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
                # Adjust device control in DP for non-CUDA platforms
                # as well as external and ray launchers
                # For CUDA platforms, we use torch.accelerator.set_device_index()
                device_control_context: contextlib.AbstractContextManager[None] = (
                    contextlib.nullcontext()
                )
                if is_dp and (
                    not current_platform.is_cuda_alike()
                    or vllm_config.parallel_config.use_ray
                ):
                    device_control_context = set_device_control_env_var(
                        vllm_config, local_dp_rank
                    )

                with (
                    device_control_context,
                    numa_utils.configure_subprocess(
                        # EngineCore itself does not have a TP/PP-local rank.
                        # When DP is enabled, set_device_control_env_var()
                        # narrows visible devices to this DP shard first, so
                        # local_rank=0 means "the first local GPU in this
                        # shard". The actual TP/PP worker processes spawned by
                        # the executor are bound separately with their own
                        # local_rank values.
                        vllm_config,
                        local_rank=0,
                        dp_local_rank=local_dp_rank,
                        process_kind="EngineCore",
                    ),
                ):
                    proc.start()
        finally:
            # Kill other procs if not all are running.
            if self.finished_procs():
                self.shutdown()

    def shutdown(self, timeout: float | None = None) -> None:
        """Shutdown engine core processes with configurable timeout.

        Always sets ``manager_stopped``. The processes themselves are only
        shut down on the first call: ``detach()`` returns ``None`` once the
        finalizer has already been detached or run.
        """
        self.manager_stopped.set()
        if self._finalizer.detach() is not None:
            shutdown(self.processes, timeout=timeout)

    def monitor_engine_liveness(self) -> None:
        """Monitor engine core process liveness.

        Blocks, waiting on the process sentinels with a 1 second timeout,
        until any engine process exits or ``manager_stopped`` is set, and then
        calls :meth:`shutdown`. A non-zero exit observed while the manager is
        still running is recorded in ``failed_proc_name``.
        """
        sentinel_to_proc = {proc.sentinel: proc for proc in self.processes}
        sentinels = set(sentinel_to_proc.keys())

        while sentinels and not self.manager_stopped.is_set():
            died_sentinels = connection.wait(sentinels, timeout=1)

            for sentinel in died_sentinels:
                proc = sentinel_to_proc.pop(cast(int, sentinel))
                exitcode = proc.exitcode
                if exitcode != 0 and not self.manager_stopped.is_set():
                    self.failed_proc_name = proc.name
            if died_sentinels:
                # Any engine exit currently triggers a shutdown. Future
                # work (e.g., Elastic and fault-tolerant EP) will add finer-grained
                # handling for different exit scenarios.
                break

        self.shutdown()

    def sentinels(self) -> list:
        """Return the ``sentinel`` handle of every managed process."""
        return [proc.sentinel for proc in self.processes]

    def finished_procs(self) -> dict[str, int]:
        """Returns dict of proc name -> exit code for any finished procs."""
        return {
            proc.name: proc.exitcode
            for proc in self.processes
            if proc.exitcode is not None
        }


class SignalCallback:
    """Defer a callback from signal-handler context to a worker thread.

    ``trigger()`` only sets a ``threading.Event``, which is cheap enough to do
    inside a signal handler; the callback itself runs on a daemon thread named
    ``signal-callback``. ``stop()`` also releases that thread, and the
    callback is skipped if the thread sees the stop flag when it wakes.
    """

    def __init__(self, callback: Callable[[], None]):
        self._callback = callback
        self._event = threading.Event()
        self._stopped = False
        worker = threading.Thread(
            target=self._run, name="signal-callback", daemon=True
        )
        self._thread = worker
        worker.start()

    def _run(self):
        # Park until either trigger() or stop() sets the event.
        self._event.wait()
        if self._stopped:
            return
        self._callback()

    def trigger(self):
        self._event.set()

    def stop(self):
        # Raise the flag before waking the thread so it skips the callback.
        self._stopped = True
        self._event.set()


@contextlib.contextmanager
def set_device_control_env_var(
    vllm_config: VllmConfig, local_dp_rank: int
) -> Iterator[None]:
    """
    Temporarily set CUDA_VISIBLE_DEVICES or equivalent
    for engine subprocess.

    The platform's device-control variable is patched in ``os.environ`` to
    the indices chosen by ``get_device_indices`` for ``local_dp_rank``, using
    the config's ``world_size`` and ``local_world_size``; ``patch.dict``
    restores the environment when the context exits.
    """
    world_size = vllm_config.parallel_config.world_size
    local_world_size = vllm_config.parallel_config.local_world_size
    evar = current_platform.device_control_env_var

    value = get_device_indices(evar, local_dp_rank, world_size, local_world_size)
    with patch.dict(os.environ, values=((evar, value),)):
        yield


def get_device_indices(
    device_control_env_var: str,
    local_dp_rank: int,
    world_size: int,
    local_world_size: int | None = None,
):
    """
    Returns a comma-separated string of device indices for the specified
    data parallel rank.

    For example, if world_size=2 and local_dp_rank=1, and there are 4 devices,
    this will select devices 2 and 3 for local_dp_rank=1.

    Args:
        device_control_env_var: Name of the device-control environment
            variable; only used to build the error message.
        local_dp_rank: Node-local data-parallel rank.
        world_size: Stride between DP ranks; the selection starts at logical
            index ``local_dp_rank * world_size``.
        local_world_size: Number of devices to select. Defaults to
            ``world_size``.

    Raises:
        Exception: if the platform raises ``IndexError`` while mapping a
            logical index to a physical device id. The message reports the
            requested range and the variable's current value.
    """
    if local_world_size is None:
        local_world_size = world_size

    start = local_dp_rank * world_size
    end = start + local_world_size
    try:
        value = ",".join(
            str(current_platform.device_id_to_physical_device_id(i))
            for i in range(start, end)
        )
    except IndexError as e:
        # Report the range actually requested: its end depends on
        # local_world_size, not on world_size.
        raise Exception(
            f"Error setting {device_control_env_var}: "
            f"local range: [{start}, {end}) "
            "base value: "
            f'"{os.getenv(device_control_env_var)}"'
        ) from e
    return value


311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
class CoreEngineActorManager:
    """
    Utility class to handle creation, readiness, and shutdown
    of core engine Ray actors used by the AsyncLLM and LLMEngine.

    Different from CoreEngineProcManager, this class manages
    core engines for both local and remote nodes.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        placement_groups: list["PlacementGroup"] | None = None,
        local_dp_ranks: list[int] | None = None,
    ):
        """Create one engine-core Ray actor per DP rank and start them.

        Args:
            vllm_config: Engine configuration. It is deep-copied per DP rank
                so each actor gets its own placement group (and, when KV
                transfer is configured with DP, a per-rank ``engine_id``).
            addresses: ZMQ addresses passed to every actor.
            executor_class: Executor class passed to every actor.
            log_stats: Passed to every actor.
            placement_groups: Pre-created placement groups, one per DP rank.
                When omitted, groups are created with
                ``create_dp_placement_groups`` and kept in
                ``created_placement_groups``.
            local_dp_ranks: Node-local DP rank for each entry of
                ``placement_groups``; required when ``placement_groups`` is
                given.

        The first ``data_parallel_size_local`` ranks are treated as local
        clients. After every actor's ``wait_for_init`` completes, each
        actor's ``run`` is launched and its object ref is kept in
        ``run_refs`` and ``actor_run_ref_dict``.
        """
        import copy

        import ray
        from ray.runtime_env import RuntimeEnv
        from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

        from vllm.v1.engine.core import DPMoEEngineCoreActor, EngineCoreActor

        dp_size = vllm_config.parallel_config.data_parallel_size
        actor_class = (
            DPMoEEngineCoreActor
            if dp_size > 1 and vllm_config.model_config.is_moe
            else EngineCoreActor
        )

        self.local_engine_actors: list[ray.ActorHandle] = []
        self.remote_engine_actors: list[ray.ActorHandle] = []

        env_vars_list = get_env_vars_to_copy(destination=actor_class.__name__)
        self.env_vars_dict = {
            name: os.environ[name] for name in env_vars_list if name in os.environ
        }
        runtime_env = RuntimeEnv(env_vars=self.env_vars_dict)

        self.addresses = addresses
        self.executor_class = executor_class
        self.log_stats = log_stats
        local_engine_count = vllm_config.parallel_config.data_parallel_size_local
        world_size = vllm_config.parallel_config.world_size

        self.manager_stopped = threading.Event()
        self.failed_proc_name: str | None = None

        if ray.is_initialized():
            logger.info("Ray is already initialized. Skipping Ray initialization.")
        else:
            ray.init()

        parallel_config = vllm_config.parallel_config
        if parallel_config.enable_elastic_ep:
            from vllm.distributed.utils import create_tcp_store

            ip = parallel_config.data_parallel_master_ip
            store = create_tcp_store(
                ip,
                0,
                is_master=True,
                world_size=-1,
                wait_for_workers=False,
            )
            parallel_config._coord_store_port = store.port
            # Keep a reference so the store stays alive with the manager.
            self._coord_store = store

        if placement_groups is not None:
            assert local_dp_ranks is not None, (
                "local_dp_ranks must be provided if placement_groups is provided"
            )
            assert len(placement_groups) == len(local_dp_ranks), (
                "placement_groups and local_dp_ranks must have the same length"
            )
            logger.info("Using provided placement groups")
            # TODO(rui): validate passed-in placement groups
            self.created_placement_groups = []
        else:
            placement_groups, local_dp_ranks = (
                CoreEngineActorManager.create_dp_placement_groups(vllm_config)
            )
            self.created_placement_groups = placement_groups
        assert len(placement_groups) == dp_size, (
            "Number of placement groups must match data parallel size"
        )

        self.placement_group_is_local = []
        refs = []
        for index, local_index, pg in zip(
            range(dp_size), local_dp_ranks, placement_groups
        ):
            dp_vllm_config = copy.deepcopy(vllm_config)
            dp_vllm_config.parallel_config.placement_group = pg
            local_client = index < local_engine_count

            if dp_size > 1 and dp_vllm_config.kv_transfer_config is not None:
                # modify the engine_id and append the local_dp_rank to it to ensure
                # that the kv_transfer_config is unique for each DP rank.
                dp_vllm_config.kv_transfer_config.engine_id = (
                    f"{dp_vllm_config.kv_transfer_config.engine_id}_dp{local_index}"
                )

            # Ray XPU known issue: dpctl initializes the GPU runtime early, so
            # setting device env vars in Ray actor's initialization method
            # will not affect device selection. See:
            # https://github.com/ray-project/ray/blob/master/python/ray/_private/accelerators/intel_gpu.py#L56 # noqa: E501
            if current_platform.is_xpu():
                device_evar = current_platform.device_control_env_var
                device_indices = get_device_indices(
                    device_evar, local_index, world_size
                )
                actor_env_vars = self.env_vars_dict.copy()
                actor_env_vars[device_evar] = device_indices
                runtime_env = RuntimeEnv(env_vars=actor_env_vars)

            # The engine actor goes on the CPU control bundle, which sits
            # right after the world_size device bundles.
            actor = (
                ray.remote(actor_class)
                .options(
                    scheduling_strategy=PlacementGroupSchedulingStrategy(
                        placement_group=pg,
                        placement_group_bundle_index=world_size,
                    ),
                    runtime_env=runtime_env,
                )
                .remote(
                    vllm_config=dp_vllm_config,
                    executor_class=executor_class,
                    log_stats=log_stats,
                    local_client=local_client,
                    addresses=addresses,
                    dp_rank=index,
                    local_dp_rank=local_index,
                )
            )
            if local_client:
                self.local_engine_actors.append(actor)
            else:
                self.remote_engine_actors.append(actor)
            self.placement_group_is_local.append(local_client)
            refs.append(actor.wait_for_init.remote())

        # Block until every actor has finished initializing before running.
        ray.get(refs)
        self.run_refs = []
        self.actor_run_ref_dict = {}
        for actor in self.local_engine_actors + self.remote_engine_actors:
            ref = actor.run.remote()
            self.run_refs.append(ref)
            self.actor_run_ref_dict[actor] = ref

    @staticmethod
    def create_dp_placement_groups(
        vllm_config: VllmConfig,
    ) -> tuple[list["PlacementGroup"], list[int]]:
        """
        Create placement groups for data parallel.

        Nodes are visited with the DP master node first. Each placement group
        holds ``world_size`` single-device bundles followed by one CPU-only
        control bundle (index ``world_size``) for the engine actor. Packing is
        controlled by ``VLLM_RAY_DP_PACK_STRATEGY``:

        * ``strict``: each node hosts exactly ``data_parallel_size_local``
          ranks; nodes that cannot fit that many are skipped.
        * ``fill``: the master node hosts ``data_parallel_size_local`` ranks
          and other nodes host as many as fit.
        * ``span``: a rank may span several nodes (``PACK`` strategy); this
          requires homogeneous nodes and ``data_parallel_size_local == 1``.

        Node iteration stops as soon as ``data_parallel_size`` groups exist,
        so no surplus groups are reserved.

        Returns:
            ``(placement_groups, local_dp_ranks)``, each of length
            ``data_parallel_size``.

        Raises:
            ValueError: for an unsupported pack strategy, ``fill`` combined
                with a DeepEP all2all backend, too few devices on the DP
                master node, or too few resources cluster-wide.
        """

        import ray
        from ray._private.state import available_resources_per_node

        logger.info("Creating placement groups for data parallel")
        dp_master_ip = vllm_config.parallel_config.data_parallel_master_ip
        dp_size = vllm_config.parallel_config.data_parallel_size
        dp_size_local = vllm_config.parallel_config.data_parallel_size_local

        available_resources = available_resources_per_node()
        world_size = vllm_config.parallel_config.world_size
        placement_groups: list[PlacementGroup] = []
        local_dp_ranks: list[int] = []

        # Sort so the DP master node comes first (False sorts before True).
        dp_master_ip_key = f"node:{dp_master_ip}"
        nodes = sorted(
            available_resources.values(), key=lambda x: dp_master_ip_key not in x
        )
        assert len(nodes) > 0, "No nodes with resources found in Ray cluster."
        assert dp_master_ip_key in nodes[0], (
            f"The DP master node (ip: {dp_master_ip}) is missing or dead"
        )
        device_str = current_platform.ray_device_key
        n_node_devices: list[int] = [
            int(node_resources[device_str])
            for node_resources in nodes
            if device_str in node_resources
        ]
        assert n_node_devices, f"No {device_str} found in Ray cluster."
        max_device_per_node = max(n_node_devices)

        pack_strategy = envs.VLLM_RAY_DP_PACK_STRATEGY
        _supported_pack_strategies = ("strict", "fill", "span")
        if pack_strategy not in _supported_pack_strategies:
            raise ValueError(
                f"{envs.VLLM_RAY_DP_PACK_STRATEGY} is not supported. "
                "Make sure to set `VLLM_RAY_DP_PACK_STRATEGY` "
                f"to one of {_supported_pack_strategies}"
            )

        all2all_backend = vllm_config.parallel_config.all2all_backend
        if pack_strategy == "fill" and (
            all2all_backend == "deepep_high_throughput"
            or all2all_backend == "deepep_low_latency"
        ):
            raise ValueError(
                "DeepEP kernels require EP ranks [0,7] (same for [8,15], ...) "
                "to be on the same node, but VLLM_RAY_DP_PACK_STRATEGY=fill "
                "does not guarantee that. "
                "Please use VLLM_RAY_DP_PACK_STRATEGY=strict instead."
            )

        if pack_strategy in ("strict", "fill"):
            placement_strategy = "STRICT_PACK"
        else:
            placement_strategy = "PACK"
            assert world_size > max_device_per_node, (
                f"World size {world_size} is smaller than the "
                "maximum number of devices per node "
                f"{max_device_per_node}. Make sure to set "
                "`VLLM_RAY_DP_PACK_STRATEGY` to `strict` or `fill`"
            )

            # if we need multiple nodes per dp group, we require for now that
            # available nodes are homogeneous
            assert set(n_node_devices) == {max_device_per_node}, (
                f"Nodes are not homogeneous, {nodes}"
            )
            assert world_size % max_device_per_node == 0, (
                f"For multi-node data parallel groups, world_size ({world_size}) must "
                f"be a multiple of number of devices per node ({max_device_per_node})."
            )
            assert len(n_node_devices) * max_device_per_node >= world_size * dp_size, (
                f"Not enough total available nodes ({len(n_node_devices)}) "
                f"and devices per node ({max_device_per_node}) "
                f"to satisfy required world size {world_size} and data parallel size "
                f"{dp_size}"
            )
            assert dp_size_local == 1, (
                f"data-parallel-size-local {dp_size_local} should be set as the "
                "default (1) for VLLM_RAY_DP_PACK_STRATEGY=span. "
                "The actual data-parallel-size-local will be auto determined."
            )

        # bundles collected for a single DP rank from multiple nodes,
        # for "span" pack strategy
        collected_bundles = []
        for node_resources in nodes:
            # Stop once every DP rank has a placement group. The inner loop's
            # break only leaves the current node; without this guard later
            # nodes would keep reserving surplus groups and trip the final
            # length check.
            if len(placement_groups) >= dp_size:
                break

            node_ip_keys = [
                key
                for key in node_resources
                if key != "node:__internal_head__" and key.startswith("node:")
            ]
            assert len(node_ip_keys) == 1, (
                f"Zero or multiple node IP keys found in node resources: {node_ip_keys}"
            )
            node_ip_key = node_ip_keys[0]
            # Split only on the first colon so IPv6 addresses stay intact
            # (same parsing as _get_bundle_node_ip).
            node_ip = node_ip_key.split(":", 1)[1]

            n_device_on_node = int(node_resources.get(device_str, 0))
            if pack_strategy == "span" and n_device_on_node != 0:
                # Strictly speaking,
                # dp_size_available = n_device_on_node / world_size
                # and is a fraction, but we use 1 for easier processing
                dp_size_available = 1
            else:
                dp_size_available = n_device_on_node // world_size

            if node_ip == dp_master_ip:
                if dp_size_available < dp_size_local:
                    raise ValueError(
                        f"Not enough resources to allocate {dp_size_local} DP ranks "
                        f"on DP master node {dp_master_ip}, possible to fit "
                        f"{dp_size_available} DP ranks."
                    )
                dp_size_to_allocate = dp_size_local
            elif pack_strategy == "strict":
                if dp_size_available < dp_size_local:
                    logger.info(
                        "Skipping node %s as %s DP ranks could not fit, "
                        "possible to fit %s DP ranks",
                        node_ip,
                        dp_size_local,
                        dp_size_available,
                    )
                    continue
                dp_size_to_allocate = dp_size_local
            else:
                # for "pack_strategy" in "fill" and "span"
                # we always take everything that's available
                dp_size_to_allocate = dp_size_available

            for i in range(dp_size_to_allocate):
                device_bundle = [{device_str: 1.0, "node:" + node_ip: 0.001}]
                if pack_strategy == "span":
                    collected_bundles += device_bundle * n_device_on_node
                    assert len(collected_bundles) <= world_size, (
                        "collected_bundles should be <= world_size, "
                        f"but got {len(collected_bundles)=} and {world_size=}"
                    )

                    # we only create a placement group if we collected enough devices
                    if len(collected_bundles) < world_size:
                        continue

                    control_node_ip = _get_bundle_node_ip(collected_bundles[0])
                    bundles = collected_bundles + [
                        _make_control_bundle(control_node_ip)
                    ]
                    collected_bundles = []
                else:
                    # STRICT_PACK already keeps every bundle in the placement
                    # group on one node, so the explicit node affinity on the
                    # control bundle is redundant for correctness here. Keep it
                    # anyway for consistency with the span path and to preserve
                    # intent if this scheduling strategy changes later.
                    bundles = device_bundle * world_size + [
                        _make_control_bundle(node_ip)
                    ]

                pg = ray.util.placement_group(
                    name=f"dp_rank_{len(placement_groups)}",
                    strategy=placement_strategy,
                    bundles=bundles,
                )
                placement_groups.append(pg)
                local_dp_ranks.append(i)
                if len(placement_groups) == dp_size:
                    break

        if len(placement_groups) < dp_size:
            raise ValueError(
                f"Not enough resources to allocate {dp_size} "
                "placement groups, only created "
                f"{len(placement_groups)} placement groups. "
                "Available resources: "
                f"{available_resources}"
            )
        assert len(placement_groups) == dp_size, (
            f"Created {len(placement_groups)} DP placement groups, expected {dp_size}"
        )
        assert len(local_dp_ranks) == dp_size, (
            f"local_dp_ranks length {len(local_dp_ranks)} does not match "
            f"expected {dp_size}"
        )
        return placement_groups, local_dp_ranks

    @staticmethod
    def add_dp_placement_groups(
        old_vllm_config: VllmConfig, new_data_parallel_size: int
    ) -> tuple[list["PlacementGroup"], list[int]]:
        """
        Add placement groups for new data parallel size.

        Creates ``new_data_parallel_size - data_parallel_size`` additional
        placement groups (none if that is not positive). Nodes are filled in
        order, starting with the DP master node. Every bundle, including the
        CPU control bundle, is pinned to the node whose free devices were
        counted, as ``create_dp_placement_groups`` does. This keeps the
        returned local DP rank, which is derived from that node's current
        usage, consistent with where the group is actually scheduled.

        Returns:
            ``(placement_groups, local_dp_ranks)``. These may be shorter than
            requested if the cluster lacks free devices.
        """
        import ray
        from ray._private.state import (
            available_resources_per_node,
            total_resources_per_node,
        )
        from ray.util.state import list_nodes

        old_dp_size = old_vllm_config.parallel_config.data_parallel_size
        num_pg_to_create = new_data_parallel_size - old_dp_size

        if num_pg_to_create <= 0:
            return [], []

        dp_master_ip = old_vllm_config.parallel_config.data_parallel_master_ip
        world_size = old_vllm_config.parallel_config.world_size

        # Visit the DP master node first (False sorts before True).
        nodes = list_nodes()
        nodes = sorted(nodes, key=lambda node: node.node_ip != dp_master_ip)
        assert nodes[0].node_ip == dp_master_ip, "The first node must be the head node"
        assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
            "There can only be one head node"
        )

        available_resources = available_resources_per_node()
        total_resources = total_resources_per_node()

        placement_groups = []
        local_dp_ranks = []
        num_pg_created = 0

        device_str = current_platform.ray_device_key
        for node in nodes:
            if num_pg_created >= num_pg_to_create:
                break

            node_ip = node.node_ip
            node_id = node.node_id
            # Skip nodes with no entry in the resource map instead of raising
            # KeyError.
            node_available = available_resources.get(node_id, {})
            if device_str not in node_available:
                continue
            available_gpus = int(node_available[device_str])

            # Get total GPUs on this node from the node's resources
            # Ray stores node resources with node ID as key
            total_gpus = int(total_resources.get(node_id, {}).get(device_str, 0))

            # Calculate used GPUs and used engines on this node
            used_gpus = max(0, total_gpus - available_gpus)
            used_engines_on_node = used_gpus // world_size

            # Calculate how many new engines this node can accommodate
            available_engine_count = available_gpus // world_size

            # Create placement groups for new engines on this node
            for i in range(available_engine_count):
                if num_pg_created >= num_pg_to_create:
                    break

                rank = old_dp_size + num_pg_created

                # Pin every bundle to this node: local_rank below is derived
                # from this node's usage, so the group must not land elsewhere.
                bundles = [
                    {device_str: 1.0, "node:" + node_ip: 0.001}
                    for _ in range(world_size)
                ] + [_make_control_bundle(node_ip)]

                pg = ray.util.placement_group(
                    name=f"dp_rank_{rank}",
                    strategy="STRICT_PACK",
                    bundles=bundles,
                )
                placement_groups.append(pg)

                # Local rank starts from the number of engines already used
                # on this node
                local_rank = used_engines_on_node + i
                local_dp_ranks.append(local_rank)
                num_pg_created += 1

        return placement_groups, local_dp_ranks

    def scale_up_elastic_ep(
        self, cur_vllm_config: VllmConfig, new_data_parallel_size: int
    ) -> None:
        """Launch additional engine-core actors to grow the DP group.

        Steps:

        1. Create one placement group per new DP rank via
           ``add_dp_placement_groups``.
        2. Start an engine-core actor inside each placement group.
        3. Wait for all newly created actors to initialize.
        4. Start their ``run`` loops and record the run refs.
        5. Record the new DP sizes on ``cur_vllm_config`` in place.

        Args:
            cur_vllm_config: Config of the running deployment.
                - Each new actor receives a deep copy with
                  ``data_parallel_size`` and ``placement_group`` updated.
                - Master-node actors additionally get an updated
                  ``data_parallel_size_local``.
                - At the end, this object's ``data_parallel_size`` is set
                  to ``new_data_parallel_size``.
                - Its ``data_parallel_size_local`` is increased by the
                  number of engines added on the master node.
            new_data_parallel_size: Target DP size. It must be greater than
                the current number of engine actors (asserted).
        """
        import copy

        import ray
        from ray.runtime_env import RuntimeEnv
        from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

        from vllm.v1.engine.core import DPMoEEngineCoreActor, EngineCoreActor

        actor_class = (
            DPMoEEngineCoreActor
            if cur_vllm_config.model_config.is_moe
            else EngineCoreActor
        )

        cur_data_parallel_size = len(self.local_engine_actors) + len(
            self.remote_engine_actors
        )

        assert new_data_parallel_size > cur_data_parallel_size, (
            f"New data parallel size {new_data_parallel_size} must be greater "
            f"than current data parallel size {cur_data_parallel_size} "
            "for scale up"
        )

        placement_groups, local_dp_ranks = self.add_dp_placement_groups(
            cur_vllm_config, new_data_parallel_size
        )

        world_size = cur_vllm_config.parallel_config.world_size
        dp_master_ip = cur_vllm_config.parallel_config.data_parallel_master_ip
        new_local_engines = 0
        # Track the actors created by this call explicitly rather than slicing
        # the tails of the actor lists afterwards. With zero new remote
        # engines, ``remote_engine_actors[-0:]`` is the WHOLE list, which would
        # re-wait on and re-``run`` every pre-existing remote actor.
        new_local_actors = []
        new_remote_actors = []

        runtime_env = RuntimeEnv(
            env_vars=self.env_vars_dict | {"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": "1"}
        )
        for i, (pg, local_rank) in enumerate(zip(placement_groups, local_dp_ranks)):
            rank = cur_data_parallel_size + i
            dp_vllm_config = copy.deepcopy(cur_vllm_config)
            dp_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
            dp_vllm_config.parallel_config.placement_group = pg

            # Check if this placement group is on the head node: master-node
            # groups carry a ``node:<dp_master_ip>`` resource in their bundles.
            local_client = any(
                bundle.get("node:" + dp_master_ip, 0) > 0 for bundle in pg.bundle_specs
            )

            if local_client:
                new_local_engines += 1
                # The k-th new local engine sees the current local size + k.
                dp_vllm_config.parallel_config.data_parallel_size_local = (
                    cur_vllm_config.parallel_config.data_parallel_size_local
                    + new_local_engines
                )

            actor = (
                ray.remote(actor_class)
                .options(
                    scheduling_strategy=PlacementGroupSchedulingStrategy(
                        placement_group=pg,
                        # Bundles [0, world_size) hold the devices; the
                        # trailing CPU-only bundle hosts the engine actor.
                        placement_group_bundle_index=world_size,
                    ),
                    runtime_env=runtime_env,
                )
                .remote(
                    vllm_config=dp_vllm_config,
                    executor_class=self.executor_class,
                    log_stats=self.log_stats,
                    local_client=local_client,
                    addresses=self.addresses,
                    dp_rank=rank,
                    local_dp_rank=local_rank,
                )
            )

            if local_client:
                self.local_engine_actors.append(actor)
                new_local_actors.append(actor)
            else:
                self.remote_engine_actors.append(actor)
                new_remote_actors.append(actor)
            self.created_placement_groups.append(pg)
            self.placement_group_is_local.append(local_client)

        # Same ordering as before: new local actors first, then new remote.
        new_actors = new_local_actors + new_remote_actors

        ray.get([actor.wait_for_init.remote() for actor in new_actors])

        for actor in new_actors:
            ref = actor.run.remote()
            self.run_refs.append(ref)
            self.actor_run_ref_dict[actor] = ref

        cur_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # Update cur_vllm_config with the new data_parallel_size_local if any
        # new local engines were added.
        if new_local_engines > 0:
            cur_vllm_config.parallel_config.data_parallel_size_local += (
                new_local_engines
            )

    def scale_down_elastic_ep(
        self, cur_data_parallel_size: int, new_data_parallel_size: int
    ) -> None:
        """Tear down the most recently added DP ranks.

        Processes ``cur_data_parallel_size - new_data_parallel_size`` entries,
        newest first. For each entry it:

        - pops the last placement group and its local/remote flag;
        - drops the last handle from the matching actor list;
        - removes the placement group from the Ray cluster.

        NOTE(review):
        - The popped actors are not ``ray.kill``-ed here. This relies on Ray
          terminating actors scheduled in a removed placement group.
        - Run refs of the removed actors are cleaned up separately by
          ``remove_run_refs_for_scale_down``. That method reads the same list
          tails, so presumably it must be called before this one. Confirm
          against the callers.

        Args:
            cur_data_parallel_size: Current DP size.
            new_data_parallel_size: Target DP size. It must be smaller than
                the current size (asserted).
        """
        import ray

        assert cur_data_parallel_size > new_data_parallel_size, (
            f"cur_data_parallel_size {cur_data_parallel_size} must be greater "
            f"than new_data_parallel_size {new_data_parallel_size} "
            "for scale down"
        )

        for _ in range(cur_data_parallel_size - new_data_parallel_size):
            pg = self.created_placement_groups.pop()
            is_local = self.placement_group_is_local.pop()
            if is_local:
                self.local_engine_actors.pop()
            else:
                self.remote_engine_actors.pop()
            ray.util.remove_placement_group(pg)

    def remove_run_refs_for_scale_down(self, removed_dp_size: int) -> None:
        """Forget the ``run()`` refs of the engines about to be scaled away.

        How the removed actors are located:

        - The last ``removed_dp_size`` entries of ``placement_group_is_local``
          describe the DP ranks being removed; the newest entry is last.
        - The entries are walked from newest to oldest.
        - A local flag takes the next actor from the end of
          ``local_engine_actors``.
        - A remote flag takes the next actor from the end of
          ``remote_engine_actors``.

        Each actor found has its ref dropped from both
        ``actor_run_ref_dict`` and ``run_refs``.

        The actor lists and placement-group bookkeeping are left unchanged.
        A non-positive ``removed_dp_size`` does nothing.
        """
        if removed_dp_size <= 0:
            return
        removed_flags = self.placement_group_is_local[-removed_dp_size:]
        # Index of the next actor to consume (counting down) in each list.
        local_cursor = len(self.local_engine_actors) - 1
        remote_cursor = len(self.remote_engine_actors) - 1
        for flag in removed_flags[::-1]:
            if not flag:
                actor = self.remote_engine_actors[remote_cursor]
                remote_cursor -= 1
            else:
                actor = self.local_engine_actors[local_cursor]
                local_cursor -= 1
            run_ref = self.actor_run_ref_dict.pop(actor)
            self.run_refs.remove(run_ref)

    def get_run_refs(self):
        """Return the engines' ``run()`` object refs.

        The manager's own list object is returned, not a copy. Callers
        therefore observe later appends and removals made by elastic
        scale-up and scale-down.
        """
        refs = self.run_refs
        return refs

    def monitor_engine_liveness(self) -> None:
        """Watch the engine actors' ``run()`` refs and shut down on failure.

        Loop behaviour:

        - Polls the current run refs with ``ray.wait`` using a 5 s timeout.
        - Stops when the manager is stopped, when there is nothing left to
          watch, or when a finished ref turns out to have failed.
        - Ignores finished refs that are no longer in ``run_refs``; these
          were removed meanwhile by elastic scale-down.
        - On a failure, records the offending ref in ``failed_proc_name``.

        ``shutdown()`` is called once the loop ends.

        NOTE(review): a ref whose ``run()`` returned without error stays in
        ``run_refs``. ``ray.wait`` then reports it as ready immediately on
        every iteration, so this loop would spin instead of waiting for the
        timeout. Confirm that ``run()`` only returns on shutdown.
        """
        import ray

        while not self.manager_stopped.is_set():
            actor_run_refs = list(self.get_run_refs())
            if not actor_run_refs:
                logger.info(
                    "There are no actors to monitor currently. "
                    "The monitoring function is about to terminate."
                )
                break
            actor_done_refs, _ = ray.wait(actor_run_refs, timeout=5)
            unexpected_failure = False
            for actor_ref in actor_done_refs:
                if self.manager_stopped.is_set():
                    break
                if actor_ref not in self.get_run_refs():
                    # The run refs may have been updated by elastic scale-down.
                    continue
                try:
                    ray.get(actor_ref)
                except ray.exceptions.RayError:
                    # RayError is the base of both RayActorError (the actor
                    # died) and RayTaskError (run() raised). Catching only
                    # RayActorError let a task error escape this thread
                    # without ever reaching shutdown().
                    self.failed_proc_name = f"Actor {actor_ref}"
                    unexpected_failure = True

            if unexpected_failure:
                break

        self.shutdown()

    def shutdown(self, timeout: float | None = None) -> None:
        """Stop monitoring, kill every engine actor and free its resources.

        Steps, in order:

        1. Set ``manager_stopped``; ``monitor_engine_liveness`` checks this
           flag to end its loop.
        2. ``ray.kill`` every local and remote engine actor.
        3. Remove every placement group this manager created.

        Args:
            timeout: Not used by this implementation. It is presumably kept
                for signature parity with the process-based engine manager.
        """
        import ray

        self.manager_stopped.set()
        for actor in self.local_engine_actors + self.remote_engine_actors:
            ray.kill(actor)
        for pg in self.created_placement_groups:
            ray.util.remove_placement_group(pg)


def get_engine_zmq_addresses(
    vllm_config: VllmConfig,
    num_api_servers: int = 1,
) -> EngineZmqAddresses:
    """Allocate ZMQ addresses for engine-client communication.

    Allocates one input (request) address and one output (response) address
    per API server.

    The addresses are local-only when this front-end talks exclusively to
    co-located engines, i.e. in any of these cases:

    - offline mode;
    - ``local_engines_only``;
    - every DP rank is local.

    With elastic EP they are never local-only, because a later scale-up may
    add engines on other nodes.

    Args:
        vllm_config: Engine config. Only ``parallel_config`` is read.
        num_api_servers: Number of front-end API server processes.

    Returns:
        ``EngineZmqAddresses`` with ``inputs`` and ``outputs`` populated.
        No other fields are set here.
    """
    parallel_config = vllm_config.parallel_config
    local_engine_count = parallel_config.data_parallel_size_local
    local_start_index = parallel_config.data_parallel_rank_local
    dp_size = parallel_config.data_parallel_size
    host = parallel_config.data_parallel_master_ip
    local_engines_only = parallel_config.local_engines_only

    # In offline mode there is an LLM instance per DP rank and
    # one core engine per LLM, see
    # examples/offline_inference/data_parallel.py.
    offline_mode = local_start_index is not None

    # client_local_only = True for cases where this front-end
    # sends requests only to colocated engines.
    client_local_only = (
        offline_mode or local_engines_only or (local_engine_count == dp_size)
    )

    # NOTE(yongji): handling scaling from intra-node to inter-node
    if parallel_config.enable_elastic_ep:
        client_local_only = False

    return EngineZmqAddresses(
        inputs=[
            get_engine_client_zmq_addr(client_local_only, host)
            for _ in range(num_api_servers)
        ],
        outputs=[
            get_engine_client_zmq_addr(client_local_only, host)
            for _ in range(num_api_servers)
        ],
    )


@contextlib.contextmanager
def launch_core_engines(
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
    addresses: EngineZmqAddresses,
    num_api_servers: int = 1,
) -> Iterator[
    tuple[
        CoreEngineProcManager | CoreEngineActorManager | None,
        DPCoordinator | None,
        EngineZmqAddresses,
        Queue | None,
    ]
]:
    """Launch engine and DP coordinator processes as needed.

    Yields ``(engine_manager, coordinator, addresses, tensor_queue)``:

    * ``engine_manager`` is one of:
      - a ``CoreEngineActorManager`` for the Ray DP backend;
      - a ``CoreEngineProcManager`` when this node runs local engine
        processes;
      - None when this node runs no local engines (headless rank 0).
    * ``coordinator`` is the ``DPCoordinator`` when this process started one,
      else None. One is started in online mode on rank 0 when
      ``vllm_config.needs_dp_coordinator`` is set. In that case its socket
      addresses are written into ``addresses`` before the yield.
    * ``tensor_queue`` is a multiprocessing queue for multimodal tensor IPC
      when ``mm_tensor_ipc == "torch_shm"``, else None.

    Non-Ray path: when the caller's ``with`` block exits normally, the
    generator resumes and blocks in ``wait_for_engine_startup``. It returns
    only after every expected engine has completed the handshake.

    Ray backend: nothing runs after the yield.

    Args:
        vllm_config: Full engine config. The DP topology is read from
            ``parallel_config``.
        executor_class: Executor class passed through to the engines.
        log_stats: Passed through to the engines.
        addresses: Front-end ZMQ addresses. When a coordinator is started,
            its addresses are written into this object in place.
        num_api_servers: Not read by this function body.
    """

    parallel_config = vllm_config.parallel_config
    dp_size = parallel_config.data_parallel_size
    local_engine_count = parallel_config.data_parallel_size_local
    local_start_index = parallel_config.data_parallel_rank_local
    dp_rank = parallel_config.data_parallel_rank
    host = parallel_config.data_parallel_master_ip
    local_engines_only = parallel_config.local_engines_only

    offline_mode = local_start_index is not None

    # Create a single tensor IPC queue for sharing multimodal tensors between
    # API servers and engine core. Returns a single queue since we only support
    # DP=1 for this data flow.
    tensor_queue: Queue | None = None
    multimodal_config = vllm_config.model_config.multimodal_config
    if multimodal_config is not None and multimodal_config.mm_tensor_ipc == "torch_shm":
        tensor_queue = get_mp_context().Queue()

    # Run the DP Coordinator process with rank 0 when in online DP mode.
    # The coordinator is needed for:
    # 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
    # 2. MoE models: wave coordination in addition to stats
    run_coordinator = (
        vllm_config.needs_dp_coordinator and not offline_mode and dp_rank == 0
    )

    if run_coordinator:
        coordinator = DPCoordinator(
            parallel_config,
            enable_wave_coordination=vllm_config.model_config.is_moe,
        )

        addresses.coordinator_input, addresses.coordinator_output = (
            coordinator.get_engine_socket_addresses()
        )
        addresses.frontend_stats_publish_address = (
            coordinator.get_stats_publish_address()
        )

        logger.info("Started DP Coordinator process (PID: %d)", coordinator.proc.pid)
    else:
        coordinator = None

    if parallel_config.data_parallel_backend == "ray":
        logger.info("Starting ray-based data parallel backend")

        engine_actor_manager = CoreEngineActorManager(
            vllm_config=vllm_config,
            addresses=addresses,
            executor_class=executor_class,
            log_stats=log_stats,
        )

        yield engine_actor_manager, coordinator, addresses, tensor_queue
        return

    if offline_mode:
        assert local_engine_count == 1
        engines_to_handshake = [CoreEngine(index=dp_rank, local=True)]
    elif dp_rank == 0:
        # Rank 0 holds Coordinator, so it handshakes with all Cores
        # in both external dplb and internal dplb mode.
        # Note this also covers the case where we have zero local engines
        # and rank 0 is headless.
        engines_to_handshake = [
            CoreEngine(index=i, local=(i < local_engine_count)) for i in range(dp_size)
        ]
    else:
        # Rank > 0 handshakes with just the local cores it is managing.
        assert local_engines_only, (
            "Attempting to launch core_engines from dp_rank > 0, but "
            "found internal DPLB, which is incompatible."
        )
        engines_to_handshake = [
            CoreEngine(index=i, local=True)
            for i in range(dp_rank, dp_rank + local_engine_count)
        ]

    # Whether the started engines will handshake only with co-located
    # front-end processes. In external_dp_lb mode, ranks > 0 handshake with
    # their co-located frontend and also the rank 0 front-end, and hence this
    # will be False.
    handshake_local_only = offline_mode or local_engine_count == dp_size

    # NOTE(yongji): handling scaling from intra-node to inter-node
    if parallel_config.enable_elastic_ep:
        handshake_local_only = False

    handshake_address = get_engine_client_zmq_addr(
        handshake_local_only, host, parallel_config.data_parallel_rpc_port
    )

    if local_engines_only and dp_rank > 0:
        # This front-end binds a private IPC path for its own handshake. The
        # engines are also given that path as client_handshake_address, while
        # handshake_address still targets the DP master host's RPC port.
        assert not handshake_local_only
        local_handshake_address = get_open_zmq_ipc_path()
        client_handshake_address = local_handshake_address
    else:
        local_handshake_address = handshake_address
        client_handshake_address = None

    with zmq_socket_ctx(
        local_handshake_address, zmq.ROUTER, bind=True
    ) as handshake_socket:
        # Start local engines.
        if local_engine_count:
            local_engine_manager = CoreEngineProcManager(
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=log_stats,
                handshake_address=handshake_address,
                client_handshake_address=client_handshake_address,
                local_client=True,
                local_engine_count=local_engine_count,
                start_index=dp_rank,
                local_start_index=local_start_index or 0,
                tensor_queue=tensor_queue,
            )
        else:
            local_engine_manager = None

        yield local_engine_manager, coordinator, addresses, tensor_queue

        # Now wait for engines to start.
        wait_for_engine_startup(
            handshake_socket,
            addresses,
            engines_to_handshake,
            parallel_config,
            dp_size > 1 and vllm_config.model_config.is_moe,
            vllm_config.cache_config,
            local_engine_manager,
            coordinator.proc if coordinator else None,
        )


def wait_for_engine_startup(
    handshake_socket: zmq.Socket,
    addresses: EngineZmqAddresses,
    core_engines: list[CoreEngine],
    parallel_config: ParallelConfig,
    coordinated_dp: bool,
    cache_config: CacheConfig,
    proc_manager: CoreEngineProcManager | None,
    coord_process: Process | None,
):
    """Run the HELLO/READY handshake with every expected engine core.

    Protocol:

    1. Each engine sends HELLO.
    2. This function answers with an ``EngineHandshakeMetadata`` init
       message.
    3. The engine then sends READY.

    The loop keeps polling until no engine is still pending connection or
    start-up.

    Args:
        handshake_socket: ROUTER socket the engines send their messages to.
        addresses: Addresses sent to each engine in the init message.
        core_engines: Engines expected to handshake. Each one's ``state`` is
            advanced in place: NEW -> CONNECTED -> READY.
        parallel_config: Provides the expected local engine count and the DP
            LB mode. When ``coordinated_dp`` is set, it also provides the DP
            settings and config hash shared with the engines.
        coordinated_dp: Whether DP settings are sent on HELLO and the
            config hash is validated on READY.
        cache_config: Not read by the handshake logic shown here.
        proc_manager: Manager of the local engine processes, whose exit
            sentinels are polled; or None.
        coord_process: DP coordinator process, whose exit is polled; or None.

    Raises:
        RuntimeError: In any of these cases:
            - a local engine process or the coordinator exits;
            - a message arrives from an unknown DP rank;
            - local/remote or headless mode does not match expectations;
            - the READY config hash mismatches (coordinated DP only);
            - a message arrives in an unexpected state.
    """
    # Wait for engine core process(es) to send ready messages.
    local_count = parallel_config.data_parallel_size_local
    remote_count = len(core_engines) - local_count
    # [local, remote] counts
    conn_pending, start_pending = [local_count, remote_count], [0, 0]
    poller = zmq.Poller()
    poller.register(handshake_socket, zmq.POLLIN)

    remote_should_be_headless = (
        not parallel_config.data_parallel_hybrid_lb
        and not parallel_config.data_parallel_external_lb
    )

    if proc_manager is not None:
        for sentinel in proc_manager.sentinels():
            poller.register(sentinel, zmq.POLLIN)
    if coord_process is not None:
        poller.register(coord_process.sentinel, zmq.POLLIN)
    while any(conn_pending) or any(start_pending):
        events = poller.poll(STARTUP_POLL_PERIOD_MS)
        if not events:
            if any(conn_pending):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) to connect.",
                    *conn_pending,
                )
            if any(start_pending):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) to start.",
                    *start_pending,
                )
            continue
        if len(events) > 1 or events[0][0] != handshake_socket:
            # One of the local core processes (or the DP coordinator) exited.
            finished = proc_manager.finished_procs() if proc_manager else {}
            if coord_process is not None and coord_process.exitcode is not None:
                finished[coord_process.name] = coord_process.exitcode
            raise RuntimeError(
                "Engine core initialization failed. "
                "See root cause above. "
                f"Failed core proc(s): {finished}"
            )

        # Receive HELLO and READY messages from the input socket.
        eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
        eng_index = int.from_bytes(eng_identity, "little")
        engine = next((e for e in core_engines if e.identity == eng_identity), None)
        if engine is None:
            raise RuntimeError(
                f"Message from engine with unexpected data parallel rank: {eng_index}"
            )
        msg = msgspec.msgpack.decode(ready_msg_bytes)
        status, local, headless = msg["status"], msg["local"], msg["headless"]
        if local != engine.local:
            raise RuntimeError(
                f"{status} message from "
                f"{'local' if local else 'remote'} "
                f"engine {eng_index}, expected it to be "
                f"{'local' if engine.local else 'remote'}"
            )

        # Remote engines must be headless iff we aren't in hybrid or external
        # dp lb mode.
        if not local and headless != remote_should_be_headless:
            if headless:
                raise RuntimeError(
                    f"Remote engine {eng_index} must not use "
                    f"--headless in external or hybrid dp lb "
                    f"mode"
                )
            else:
                raise RuntimeError(
                    f"Remote engine {eng_index} must use "
                    f"--headless unless in external or hybrid "
                    f"dp lb mode"
                )

        if status == "HELLO" and engine.state == CoreEngineState.NEW:
            # Send init message with DP config info (only for coordinated DP).
            init_message = msgspec.msgpack.encode(
                EngineHandshakeMetadata(
                    addresses=addresses,
                    parallel_config={
                        k: getattr(parallel_config, k)
                        for k in (
                            "data_parallel_master_ip",
                            "data_parallel_master_port",
                            "_data_parallel_master_port_list",
                            "data_parallel_size",
                        )
                    }
                    if coordinated_dp
                    else {},
                )
            )
            handshake_socket.send_multipart((eng_identity, init_message), copy=False)
            conn_pending[0 if local else 1] -= 1
            start_pending[0 if local else 1] += 1
            engine.state = CoreEngineState.CONNECTED
        elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
            # Validate config hash consistency across DP workers for MoE models.
            if coordinated_dp:
                worker_config_hash = msg.get("parallel_config_hash")
                expected_hash = parallel_config.compute_hash()
                if worker_config_hash != expected_hash:
                    raise RuntimeError(
                        f"Configuration mismatch detected for engine "
                        f"{eng_index}. All DP workers must have identical "
                        f"configurations for parameters that affect collective "
                        f"communication (e.g., enable_eplb, "
                        f"eplb_config.log_balancedness). "
                        f"Worker hash: {worker_config_hash}, "
                        f"Expected hash: {expected_hash}. "
                        f"Please ensure all workers are started with the same "
                        f"command-line arguments."
                    )

            start_pending[0 if local else 1] -= 1
            engine.state = CoreEngineState.READY
        else:
            raise RuntimeError(
                f"Unexpected {status} message for "
                f"{'local' if local else 'remote'} engine "
                f"{eng_index} in {engine.state} state."
            )

        logger.debug(
            "%s from %s core engine process %s.",
            status,
            "local" if local else "remote",
            eng_index,
        )