"docs/vscode:/vscode.git/clone" did not exist on "815079de8e9dd984d474f7046412d5aedf4350ff"
utils.py 33.9 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
5
import os
6
7
8
9
10
11
12
import weakref
from collections.abc import Iterator
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
from typing import TYPE_CHECKING, Callable, Optional, Union
13
from unittest.mock import patch
14
15
16
17
18
19

import msgspec
import zmq

from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
20
from vllm.platforms import current_platform
21
from vllm.ray.ray_env import get_env_vars_to_copy
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.executor.abstract import Executor
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)

# Timeout, in milliseconds, of each poll of the startup handshake socket;
# when a poll times out, pending engines are logged and polling resumes.
STARTUP_POLL_PERIOD_MS = 10_000


class CoreEngineState(Enum):
    """States tracked for each engine core during the startup handshake."""

    # Initial state of a freshly created CoreEngine record.
    NEW = auto()
    CONNECTED = auto()
    READY = auto()


class CoreEngine:
    """One per data parallel rank, used to track state during handshaking.

    Args:
        index: Data parallel rank of the engine.
        local: Whether the engine is co-located with this process.
    """

    def __init__(self, index: int = 0, local: bool = True):
        self.local = local
        # Two-byte little-endian encoding of the rank (presumably used as the
        # ZMQ routing identity during the handshake — confirm at call sites).
        self.identity = index.to_bytes(length=2, byteorder="little")
        # Every engine begins the handshake in the NEW state.
        self.state = CoreEngineState.NEW


@dataclass
class EngineZmqAddresses:
    """ZMQ socket addresses shared between front-ends, engines and the DP
    coordinator.

    Attributes:
        inputs: Input socket address for each front-end client (requests).
        outputs: Output socket address for each front-end client (responses).
        coordinator_input: Input socket address of the DP coordinator, if
            applicable.
        coordinator_output: Output socket address of the DP coordinator, if
            applicable.
        frontend_stats_publish_address: Socket the front-end uses to connect
            to the DP coordinator. Not used by the engine; it is only relayed
            to the front-end in the handshake response, and is only required
            for the external DP LB case.
    """

    inputs: list[str]
    outputs: list[str]
    coordinator_input: Optional[str] = None
    coordinator_output: Optional[str] = None
    frontend_stats_publish_address: Optional[str] = None


@dataclass
class EngineHandshakeMetadata:
    """Metadata sent to each engine process during startup handshake,
    including addresses of the front-end ZMQ queues that they should
    connect to.

    Attributes:
        addresses: ZMQ addresses of the front-end clients and, when one is
            running, of the DP coordinator.
        parallel_config: Parallel-configuration values keyed by name
            (presumably a subset of ParallelConfig fields — confirm against
            the sender).
    """

    addresses: EngineZmqAddresses
    parallel_config: dict[str, Union[int, str, list[int]]]


class CoreEngineProcManager:
    """
    Utility class to handle creation, readiness, and shutdown
    of background processes used by the AsyncLLM and LLMEngine.
    """

    def __init__(
        self,
        target_fn: Callable,
        local_engine_count: int,
        start_index: int,
        local_start_index: int,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: Optional[str] = None,
    ):
        """Create and start ``local_engine_count`` engine core processes.

        Args:
            target_fn: Entry point run in each child process.
            local_engine_count: Number of processes to launch.
            start_index: Global DP rank of the first launched process.
            local_start_index: Local DP rank of the first launched process;
                selects the devices each process is allowed to see when DP
                is enabled.
            vllm_config: Engine configuration forwarded to every process.
            local_client: Forwarded to every process.
            handshake_address: ZMQ handshake address forwarded to every
                process.
            executor_class: Executor class forwarded to every process.
            log_stats: Forwarded to every process.
            client_handshake_address: Extra handshake address; forwarded only
                when set.

        Raises:
            Any exception raised while starting a process. All processes
            already started are shut down before it propagates.
        """
        context = get_mp_context()
        common_kwargs = {
            "vllm_config": vllm_config,
            "local_client": local_client,
            "handshake_address": handshake_address,
            "executor_class": executor_class,
            "log_stats": log_stats,
        }

        # Only forwarded when set, so the target's own default applies
        # otherwise.
        if client_handshake_address:
            common_kwargs["client_handshake_address"] = client_handshake_address

        self.processes: list[BaseProcess] = []
        local_dp_ranks = []
        for index in range(local_engine_count):
            local_index = local_start_index + index
            global_index = start_index + index

            # Start EngineCore in background process.
            local_dp_ranks.append(local_index)
            self.processes.append(
                context.Process(
                    target=target_fn,
                    name=f"EngineCore_DP{global_index}",
                    kwargs=common_kwargs
                    | {
                        "dp_rank": global_index,
                        "local_dp_rank": local_index,
                    },
                )
            )

        # Shuts the processes down when this manager is garbage collected (or
        # at interpreter exit) even if close() is never called explicitly.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

        data_parallel = vllm_config.parallel_config.data_parallel_size > 1
        try:
            for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
                # With DP, each process only sees its own slice of devices
                # while it is being started.
                with (
                    set_device_control_env_var(vllm_config, local_dp_rank)
                    if data_parallel
                    else contextlib.nullcontext()
                ):
                    proc.start()
        except BaseException:
            # A later process failed to launch: don't leave the ones already
            # started running until this half-built manager is collected.
            self.close()
            raise
        else:
            # Kill other procs if not all are running.
            if self.finished_procs():
                self.close()

    def close(self):
        """Shutdown all procs.

        Safe to call more than once; a weakref finalizer only runs its
        callback the first time it is invoked.
        """
        self._finalizer()

    def join_first(self):
        """Wait for any process to exit."""
        connection.wait(self.sentinels())

    def sentinels(self) -> list:
        """Return each process's sentinel, e.g. for registering with a poller
        or ``multiprocessing.connection.wait``."""
        return [proc.sentinel for proc in self.processes]

    def finished_procs(self) -> dict[str, int]:
        """Returns dict of proc name -> exit code for any finished procs."""
        return {
            proc.name: proc.exitcode
            for proc in self.processes
            if proc.exitcode is not None
        }


@contextlib.contextmanager
def set_device_control_env_var(
    vllm_config: VllmConfig, local_dp_rank: int
) -> Iterator[None]:
    """
    Temporarily set CUDA_VISIBLE_DEVICES or equivalent
    for engine subprocess.

    The platform's device-control variable is set to the ``world_size``
    device indices assigned to ``local_dp_rank`` (see ``get_device_indices``)
    for the duration of the ``with`` block.

    Raises:
        Whatever ``get_device_indices`` raises if the device range cannot be
        resolved; in that case the environment is left untouched.
    """
    world_size = vllm_config.parallel_config.world_size
    evar = current_platform.device_control_env_var

    value = get_device_indices(evar, local_dp_rank, world_size)
    # patch.dict restores os.environ to its prior contents on exit, so the
    # variable returns to its previous value (or is removed if it was unset).
    with patch.dict(os.environ, values={evar: value}):
        yield


def get_device_indices(
    device_control_env_var: str, local_dp_rank: int, world_size: int
) -> str:
    """
    Returns a comma-separated string of device indices for the specified
    data parallel rank.

    For example, if world_size=2 and local_dp_rank=1, and there are 4 devices,
    this will select devices 2 and 3 for local_dp_rank=1.

    Each logical index in ``[local_dp_rank * world_size,
    (local_dp_rank + 1) * world_size)`` is mapped through
    ``current_platform.device_id_to_physical_device_id``.

    Args:
        device_control_env_var: Name of the device-visibility environment
            variable; only used in the error message.
        local_dp_rank: Data parallel rank local to this node.
        world_size: Number of devices used by one DP rank.

    Raises:
        RuntimeError: if the platform mapping raises ``IndexError`` for any
            index in the range (presumably the range exceeds the visible
            devices). The message includes the range and the current value
            of ``device_control_env_var``.
    """
    start = local_dp_rank * world_size
    end = start + world_size
    try:
        return ",".join(
            str(current_platform.device_id_to_physical_device_id(i))
            for i in range(start, end)
        )
    except IndexError as e:
        raise RuntimeError(
            f"Error setting {device_control_env_var}: "
            f"local range: [{start}, "
            f"{end}) "
            "base value: "
            f'"{os.getenv(device_control_env_var)}"'
        ) from e


class CoreEngineActorManager:
    """
    Utility class to handle creation, readiness, and shutdown
    of core engine Ray actors used by the AsyncLLM and LLMEngine.

    Different from CoreEngineProcManager, this class manages
    core engines for both local and remote nodes.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        placement_groups: Optional[list["PlacementGroup"]] = None,
        local_dp_ranks: Optional[list[int]] = None,
    ):
        """Create one ``DPEngineCoreActor`` per DP rank, wait for all of them
        to initialize, then start their ``run`` loops.

        Args:
            vllm_config: Base config; deep-copied per rank with that rank's
                placement group attached.
            addresses: ZMQ addresses passed to every actor.
            executor_class: Executor class passed to every actor.
            log_stats: Passed to every actor.
            placement_groups: Optional pre-created placement groups, one per
                DP rank. When omitted, groups are created by
                ``create_dp_placement_groups`` and owned by this manager.
            local_dp_ranks: Local DP rank for each entry of
                ``placement_groups``; required when ``placement_groups`` is
                given.
        """
        import copy

        import ray
        from ray.runtime_env import RuntimeEnv
        from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

        from vllm.v1.engine.core import DPEngineCoreActor

        self.local_engine_actors: list[ray.ActorHandle] = []
        self.remote_engine_actors: list[ray.ActorHandle] = []

        # Environment variables from this process forwarded to every actor.
        env_vars_list = get_env_vars_to_copy(destination="DPEngineCoreActor")
        self.env_vars_dict = {
            name: os.environ[name] for name in env_vars_list if name in os.environ
        }
        runtime_env = RuntimeEnv(env_vars=self.env_vars_dict)

        self.addresses = addresses
        self.executor_class = executor_class
        self.log_stats = log_stats
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_engine_count = vllm_config.parallel_config.data_parallel_size_local
        world_size = vllm_config.parallel_config.world_size

        if ray.is_initialized():
            logger.info("Ray is already initialized. Skipping Ray initialization.")
        else:
            ray.init()

        if placement_groups is not None:
            assert local_dp_ranks is not None, (
                "local_dp_ranks must be provided if placement_groups is provided"
            )
            assert len(placement_groups) == len(local_dp_ranks), (
                "placement_groups and local_dp_ranks must have the same length"
            )
            logger.info("Using provided placement groups")
            # TODO(rui): validate passed-in placement groups
            # Caller-provided groups are not tracked, so close() leaves them.
            self.created_placement_groups = []
        else:
            placement_groups, local_dp_ranks = (
                CoreEngineActorManager.create_dp_placement_groups(vllm_config)
            )
            self.created_placement_groups = placement_groups
        assert len(placement_groups) == dp_size, (
            "Number of placement groups must match data parallel size"
        )

        # Whether each rank's actor is a local client, in DP-rank order.
        self.placement_group_is_local = []
        refs = []
        for index, local_index, pg in zip(
            range(dp_size), local_dp_ranks, placement_groups
        ):
            dp_vllm_config = copy.deepcopy(vllm_config)
            dp_vllm_config.parallel_config.placement_group = pg
            # The first data_parallel_size_local ranks are local clients.
            local_client = index < local_engine_count

            # Ray XPU known issue: dpctl initializes the GPU runtime early, so
            # setting device env vars in Ray actor's initialization method
            # will not affect device selection. See:
            # https://github.com/ray-project/ray/blob/master/python/ray/_private/accelerators/intel_gpu.py#L56 # noqa: E501
            if current_platform.is_xpu():
                device_evar = current_platform.device_control_env_var
                device_indices = get_device_indices(
                    device_evar, local_index, world_size
                )
                actor_env_vars = self.env_vars_dict.copy()
                actor_env_vars[device_evar] = device_indices
                runtime_env = RuntimeEnv(env_vars=actor_env_vars)

            # Bundle index world_size is the trailing {"CPU": 1.0} bundle of
            # groups built by create_dp_placement_groups (assumed to hold for
            # caller-provided groups too).
            actor = (
                ray.remote(DPEngineCoreActor)
                .options(
                    scheduling_strategy=PlacementGroupSchedulingStrategy(
                        placement_group=pg,
                        placement_group_bundle_index=world_size,
                    ),
                    runtime_env=runtime_env,
                )
                .remote(
                    vllm_config=dp_vllm_config,
                    executor_class=executor_class,
                    log_stats=log_stats,
                    local_client=local_client,
                    addresses=addresses,
                    dp_rank=index,
                    local_dp_rank=local_index,
                )
            )
            if local_client:
                self.local_engine_actors.append(actor)
            else:
                self.remote_engine_actors.append(actor)
            self.placement_group_is_local.append(local_client)
            refs.append(actor.wait_for_init.remote())

        # Block until every actor has initialized before starting any loop.
        ray.get(refs)
        self.run_refs = []
        for actor in self.local_engine_actors + self.remote_engine_actors:
            self.run_refs.append(actor.run.remote())

    @staticmethod
    def create_dp_placement_groups(
        vllm_config: VllmConfig,
    ) -> tuple[list["PlacementGroup"], list[int]]:
        """
        Create placement groups for data parallel.

        One STRICT_PACK group is created per DP rank with ``world_size``
        single-device bundles plus a trailing ``{"CPU": 1.0}`` bundle. The DP
        master node receives exactly ``data_parallel_size_local`` groups; the
        remaining ranks go to other nodes exposing the platform's Ray device
        resource, as many per node as whole ``world_size`` device sets fit.

        Returns:
            ``(placement_groups, local_dp_ranks)``; ``local_dp_ranks[i]`` is
            the position of group ``i`` among the groups on its node.

        Raises:
            AssertionError: if no nodes are found, the master node is missing,
                or the master node cannot host its local engines.
            ValueError: if fewer than ``data_parallel_size`` groups could be
                created.
        """
        import ray
        from ray._private.state import available_resources_per_node

        logger.info("Creating placement groups for data parallel")
        dp_master_ip = vllm_config.parallel_config.data_parallel_master_ip
        num_pg_to_create = vllm_config.parallel_config.data_parallel_size
        local_engine_count = vllm_config.parallel_config.data_parallel_size_local

        available_resources = available_resources_per_node()
        world_size = vllm_config.parallel_config.world_size
        placement_groups: list["PlacementGroup"] = []
        local_dp_ranks: list[int] = []

        dp_master_ip_key = f"node:{dp_master_ip}"
        # Stable sort moving the node that exposes "node:<master ip>" first.
        nodes = sorted(
            available_resources.values(), key=lambda x: dp_master_ip_key not in x
        )
        assert len(nodes) > 0, "No nodes with resources found in Ray cluster."
        assert dp_master_ip_key in nodes[0], (
            f"The DP master node (ip: {dp_master_ip}) is missing or dead"
        )
        device_str = current_platform.ray_device_key
        for node_resources in nodes:
            if device_str not in node_resources:
                continue
            # For now, each DP rank can only be assigned to one node
            # TODO(rui): support allocating a single DP rank
            # to multiple nodes
            available_engine_count = int(node_resources[device_str]) // world_size
            if dp_master_ip_key in node_resources:
                assert available_engine_count >= local_engine_count, (
                    "Not enough resources to allocate DP ranks "
                    f"on DP master node {dp_master_ip}"
                )
                for i in range(local_engine_count):
                    # The tiny node:<ip> request pins the bundle to the master.
                    bundles = [
                        {device_str: 1.0, "node:" + dp_master_ip: 0.001}
                    ] * world_size + [{"CPU": 1.0}]
                    pg = ray.util.placement_group(
                        name=f"dp_rank_{len(placement_groups)}",
                        strategy="STRICT_PACK",
                        bundles=bundles,
                    )
                    placement_groups.append(pg)
                    local_dp_ranks.append(i)
            else:
                for i in range(available_engine_count):
                    if len(placement_groups) == num_pg_to_create:
                        break
                    bundles = [{device_str: 1.0}] * world_size + [{"CPU": 1.0}]
                    pg = ray.util.placement_group(
                        name=f"dp_rank_{len(placement_groups)}",
                        strategy="STRICT_PACK",
                        bundles=bundles,
                    )
                    placement_groups.append(pg)
                    local_dp_ranks.append(i)
        if len(placement_groups) < num_pg_to_create:
            raise ValueError(
                f"Not enough resources to allocate {num_pg_to_create} "
                "placement groups, only created "
                f"{len(placement_groups)} placement groups. "
                "Available resources: "
                f"{available_resources}"
            )
        return placement_groups, local_dp_ranks

    @staticmethod
    def add_dp_placement_groups(
        old_vllm_config: VllmConfig, new_data_parallel_size: int
    ) -> tuple[list["PlacementGroup"], list[int]]:
        """
        Add placement groups for new data parallel size.

        Creates ``new_data_parallel_size - data_parallel_size`` groups, with
        the same bundle layout as ``create_dp_placement_groups``. The master
        node is filled first, then the remaining nodes. Local ranks continue
        after the engines that already occupy devices on each node.

        Returns:
            ``(placement_groups, local_dp_ranks)``; both are empty when the
            new size is not larger than the old one.

        Raises:
            AssertionError: if the master node is not uniquely present.
            ValueError: if not all requested groups could be created. Any
                groups created by this call are removed first.
        """
        import ray
        from ray._private.state import (
            available_resources_per_node,
            total_resources_per_node,
        )
        from ray.util.state import list_nodes

        old_dp_size = old_vllm_config.parallel_config.data_parallel_size
        num_pg_to_create = new_data_parallel_size - old_dp_size

        if num_pg_to_create <= 0:
            return [], []

        dp_master_ip = old_vllm_config.parallel_config.data_parallel_master_ip
        world_size = old_vllm_config.parallel_config.world_size

        nodes = list_nodes()
        nodes = sorted(nodes, key=lambda node: node.node_ip != dp_master_ip)
        assert len(nodes) > 0, "No nodes found in Ray cluster."
        assert nodes[0].node_ip == dp_master_ip, "The first node must be the head node"
        assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
            "There can only be one head node"
        )

        available_resources = available_resources_per_node()
        total_resources = total_resources_per_node()

        placement_groups = []
        local_dp_ranks = []
        num_pg_created = 0

        device_str = current_platform.ray_device_key
        for node in nodes:
            if num_pg_created >= num_pg_to_create:
                break

            node_ip = node.node_ip
            node_id = node.node_id
            # Resource tables are keyed by node ID. A node without the device
            # resource, or absent from the tables (e.g. presumably a dead node
            # still reported by list_nodes), contributes zero engines. This
            # mirrors create_dp_placement_groups skipping such nodes, instead
            # of raising KeyError.
            available_gpus = int(
                available_resources.get(node_id, {}).get(device_str, 0)
            )
            total_gpus = int(total_resources.get(node_id, {}).get(device_str, 0))

            # Calculate used GPUs and used engines on this node
            used_gpus = max(0, total_gpus - available_gpus)
            used_engines_on_node = used_gpus // world_size

            # Calculate how many new engines this node can accommodate
            available_engine_count = available_gpus // world_size

            # Create placement groups for new engines on this node
            for i in range(available_engine_count):
                if num_pg_created >= num_pg_to_create:
                    break

                rank = old_dp_size + num_pg_created

                # Create bundles with node constraint for master node
                if node_ip == dp_master_ip:
                    bundles = [
                        {device_str: 1.0, "node:" + dp_master_ip: 0.001}
                    ] * world_size + [{"CPU": 1.0}]
                else:
                    bundles = [{device_str: 1.0}] * world_size + [{"CPU": 1.0}]

                pg = ray.util.placement_group(
                    name=f"dp_rank_{rank}",
                    strategy="STRICT_PACK",
                    bundles=bundles,
                )
                placement_groups.append(pg)

                # Local rank starts from the number of engines already used
                # on this node
                local_rank = used_engines_on_node + i
                local_dp_ranks.append(local_rank)
                num_pg_created += 1

        if num_pg_created < num_pg_to_create:
            # A partial allocation would let the caller record a DP size
            # larger than the number of engines actually launched.
            for pg in placement_groups:
                ray.util.remove_placement_group(pg)
            raise ValueError(
                f"Not enough resources to allocate {num_pg_to_create} "
                "placement groups, only created "
                f"{num_pg_created} placement groups. "
                "Available resources: "
                f"{available_resources}"
            )

        return placement_groups, local_dp_ranks

    def scale_up_elastic_ep(
        self, cur_vllm_config: VllmConfig, new_data_parallel_size: int
    ) -> None:
        """Launch engine actors for DP ranks ``[cur, new_data_parallel_size)``.

        New actors are waited on until initialized, then their ``run`` loops
        are started and recorded in ``run_refs``. ``cur_vllm_config`` is
        updated in place with the new ``data_parallel_size`` and, if any new
        engines are local, the new ``data_parallel_size_local``.
        """
        import copy

        import ray
        from ray.runtime_env import RuntimeEnv
        from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

        from vllm.v1.engine.core import DPEngineCoreActor

        cur_data_parallel_size = len(self.local_engine_actors) + len(
            self.remote_engine_actors
        )

        assert new_data_parallel_size > cur_data_parallel_size, (
            f"New data parallel size {new_data_parallel_size} must be greater "
            f"than current data parallel size {cur_data_parallel_size} "
            "for scale up"
        )

        placement_groups, local_dp_ranks = self.add_dp_placement_groups(
            cur_vllm_config, new_data_parallel_size
        )

        world_size = cur_vllm_config.parallel_config.world_size
        dp_master_ip = cur_vllm_config.parallel_config.data_parallel_master_ip
        new_local_engines = 0
        # Actors launched by this call. They are tracked explicitly because
        # slicing the shared actor lists with a negative length breaks when
        # that length is 0 (`lst[-0:]` is the whole list).
        new_local_actors = []
        new_remote_actors = []

        runtime_env = RuntimeEnv(
            env_vars=self.env_vars_dict | {"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": "1"}
        )
        for i, (pg, local_rank) in enumerate(zip(placement_groups, local_dp_ranks)):
            rank = cur_data_parallel_size + i
            dp_vllm_config = copy.deepcopy(cur_vllm_config)
            dp_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
            dp_vllm_config.parallel_config.placement_group = pg

            # Check if this placement group is on the head node
            local_client = any(
                bundle.get("node:" + dp_master_ip, 0) > 0 for bundle in pg.bundle_specs
            )

            if local_client:
                new_local_engines += 1
                # Update data_parallel_size_local
                dp_vllm_config.parallel_config.data_parallel_size_local = (
                    cur_vllm_config.parallel_config.data_parallel_size_local
                    + new_local_engines
                )

            actor = (
                ray.remote(DPEngineCoreActor)
                .options(
                    scheduling_strategy=PlacementGroupSchedulingStrategy(
                        placement_group=pg,
                        placement_group_bundle_index=world_size,
                    ),
                    runtime_env=runtime_env,
                )
                .remote(
                    vllm_config=dp_vllm_config,
                    executor_class=self.executor_class,
                    log_stats=self.log_stats,
                    local_client=local_client,
                    addresses=self.addresses,
                    dp_rank=rank,
                    local_dp_rank=local_rank,
                )
            )

            if local_client:
                self.local_engine_actors.append(actor)
                new_local_actors.append(actor)
            else:
                self.remote_engine_actors.append(actor)
                new_remote_actors.append(actor)
            self.created_placement_groups.append(pg)
            self.placement_group_is_local.append(local_client)

        # Local engines first, then remote ones.
        new_actors = new_local_actors + new_remote_actors
        ray.get([actor.wait_for_init.remote() for actor in new_actors])

        for actor in new_actors:
            self.run_refs.append(actor.run.remote())

        cur_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
        # Update old_vllm_config with new data_parallel_size_local if any new
        # local engines were added
        if new_local_engines > 0:
            cur_vllm_config.parallel_config.data_parallel_size_local += (
                new_local_engines
            )

    def scale_down_elastic_ep(
        self, cur_data_parallel_size: int, new_data_parallel_size: int
    ) -> None:
        """Drop the most recently added DP ranks until
        ``new_data_parallel_size`` remain.

        For each dropped rank, the last placement group this manager created
        is removed and the matching actor handle is popped. The actor itself
        is not killed here; the code presumably relies on Ray terminating
        actors scheduled in a removed placement group.

        NOTE(review): ``run_refs`` is not pruned for the dropped actors.
        Scaling down below caller-provided placement groups would also raise
        IndexError, because those groups are not tracked in
        ``created_placement_groups``.
        """
        import ray

        assert cur_data_parallel_size > new_data_parallel_size, (
            f"cur_data_parallel_size {cur_data_parallel_size} must be greater "
            f"than new_data_parallel_size {new_data_parallel_size} "
            "for scale down"
        )
        for _ in range(cur_data_parallel_size - new_data_parallel_size):
            pg = self.created_placement_groups.pop()
            is_local = self.placement_group_is_local.pop()
            if is_local:
                self.local_engine_actors.pop()
            else:
                self.remote_engine_actors.pop()
            ray.util.remove_placement_group(pg)

    def get_run_refs(self):
        """Return the object refs of every started ``run`` call."""
        return self.run_refs

    def close(self):
        """Kill every engine actor and remove the placement groups created by
        this manager (caller-provided groups are left in place)."""
        import ray

        for actor in self.local_engine_actors + self.remote_engine_actors:
            ray.kill(actor)
        for pg in self.created_placement_groups:
            ray.util.remove_placement_group(pg)


@contextlib.contextmanager
def launch_core_engines(
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
    num_api_servers: int = 1,
) -> Iterator[
    tuple[
        Optional[Union[CoreEngineProcManager, CoreEngineActorManager]],
        Optional[DPCoordinator],
        EngineZmqAddresses,
    ]
]:
    """Launch engine and DP coordinator processes as needed.

    Yields ``(engine_manager, coordinator, addresses)``:

    * ``engine_manager``: a ``CoreEngineActorManager`` with the Ray DP
      backend. Otherwise a ``CoreEngineProcManager`` for this process's local
      engines, or ``None`` when there are no local engines.
    * ``coordinator``: the ``DPCoordinator`` when ``dp_size > 1``, not in
      offline mode and ``dp_rank == 0``; otherwise ``None``.
    * ``addresses``: ZMQ addresses, one input and one output per API server.

    With the multiprocessing backend, once the caller's ``with`` body
    finishes without error, this waits for the engines to complete the
    startup handshake before exiting. With the Ray backend it exits right
    after the body.
    """

    parallel_config = vllm_config.parallel_config
    dp_size = parallel_config.data_parallel_size
    local_engine_count = parallel_config.data_parallel_size_local
    local_start_index = parallel_config.data_parallel_rank_local
    dp_rank = parallel_config.data_parallel_rank
    host = parallel_config.data_parallel_master_ip
    local_engines_only = (
        parallel_config.data_parallel_hybrid_lb
        or parallel_config.data_parallel_external_lb
    )

    # In offline mode there is an LLM instance per DP rank and
    # one core engine per LLM, see
    # examples/offline_inference/data_parallel.py.
    offline_mode = local_start_index is not None

    # client_local_only = True for cases where this front-end
    # sends requests only to colocated engines.
    client_local_only = (
        offline_mode or local_engines_only or (local_engine_count == dp_size)
    )

    # Set up input and output addresses.
    addresses = EngineZmqAddresses(
        inputs=[
            get_engine_client_zmq_addr(client_local_only, host)
            for _ in range(num_api_servers)
        ],
        outputs=[
            get_engine_client_zmq_addr(client_local_only, host)
            for _ in range(num_api_servers)
        ],
    )

    # Run the DP Coordinator process with rank 0 when in
    # online DP mode.
    run_coordinator = dp_size > 1 and not offline_mode and dp_rank == 0

    if run_coordinator:
        coordinator = DPCoordinator(parallel_config)

        addresses.coordinator_input, addresses.coordinator_output = (
            coordinator.get_engine_socket_addresses()
        )
        addresses.frontend_stats_publish_address = (
            coordinator.get_stats_publish_address()
        )

        logger.info("Started DP Coordinator process (PID: %d)", coordinator.proc.pid)
    else:
        coordinator = None

    if parallel_config.data_parallel_backend == "ray":
        logger.info("Starting ray-based data parallel backend")

        # The actor manager already waits for every actor to initialize, so
        # no handshake is needed after the caller's body.
        engine_actor_manager = CoreEngineActorManager(
            vllm_config=vllm_config,
            addresses=addresses,
            executor_class=executor_class,
            log_stats=log_stats,
        )

        yield engine_actor_manager, coordinator, addresses
        return

    if offline_mode:
        assert local_engine_count == 1
        engines_to_handshake = [CoreEngine(index=dp_rank, local=True)]
    elif dp_rank == 0:
        # Rank 0 holds Coordinator, so it handshakes with all Cores
        # in both external dplb and internal dplb mode.
        # Note this also covers the case where we have zero local engines
        # and rank 0 is headless.
        engines_to_handshake = [
            CoreEngine(index=i, local=(i < local_engine_count)) for i in range(dp_size)
        ]
    else:
        # Rank > 0 handshakes with just the local cores it is managing.
        assert local_engines_only, (
            "Attempting to launch core_engines from dp_rank > 0, but "
            "found internal DPLB, which is incompatible."
        )
        engines_to_handshake = [
            CoreEngine(index=i, local=True)
            for i in range(dp_rank, dp_rank + local_engine_count)
        ]

    # Whether the started engines will handshake only with co-located
    # front-end processes. In external_dp_lb mode, ranks > 0 handshake with
    # their co-located frontend and also the rank 0 front-end, and hence this
    # will be False.
    handshake_local_only = offline_mode or local_engine_count == dp_size

    handshake_address = get_engine_client_zmq_addr(
        handshake_local_only, host, parallel_config.data_parallel_rpc_port
    )

    if local_engines_only and dp_rank > 0:
        # This front-end binds its own IPC handshake socket. The engines are
        # also given it as client_handshake_address, while handshake_address
        # still points at the shared (rank 0) endpoint.
        assert not handshake_local_only
        local_handshake_address = get_open_zmq_ipc_path()
        client_handshake_address = local_handshake_address
    else:
        local_handshake_address = handshake_address
        client_handshake_address = None

    with zmq_socket_ctx(
        local_handshake_address, zmq.ROUTER, bind=True
    ) as handshake_socket:
        from vllm.v1.engine.core import EngineCoreProc

        # Start local engines.
        if local_engine_count:
            local_engine_manager = CoreEngineProcManager(
                EngineCoreProc.run_engine_core,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=log_stats,
                handshake_address=handshake_address,
                client_handshake_address=client_handshake_address,
                local_client=True,
                local_engine_count=local_engine_count,
                start_index=dp_rank,
                local_start_index=local_start_index or 0,
            )
        else:
            local_engine_manager = None

        yield local_engine_manager, coordinator, addresses

        # Now wait for engines to start.
        wait_for_engine_startup(
            handshake_socket,
            addresses,
            engines_to_handshake,
            parallel_config,
            vllm_config.cache_config,
            local_engine_manager,
            coordinator.proc if coordinator else None,
        )


def wait_for_engine_startup(
    handshake_socket: zmq.Socket,
    addresses: EngineZmqAddresses,
    core_engines: list[CoreEngine],
    parallel_config: ParallelConfig,
    cache_config: CacheConfig,
    proc_manager: Optional[CoreEngineProcManager],
    coord_process: Optional[Process],
):
    """Block until every engine core has completed the startup handshake.

    Each engine performs two steps over ``handshake_socket``. Messages are
    received as ``[identity, payload]`` multipart frames.

    1. The engine sends a ``HELLO`` message. We reply with an
       ``EngineHandshakeMetadata`` carrying ``addresses`` and a subset of the
       data-parallel config, and the engine moves NEW -> CONNECTED.
    2. The engine sends a ``READY`` message reporting its ``num_gpu_blocks``
       and, optionally, a ``dp_stats_address``. The engine moves
       CONNECTED -> READY.

    Args:
        handshake_socket: Bound socket that the engines connect to.
        addresses: ZMQ addresses sent to engines in the init message. Its
            ``frontend_stats_publish_address`` is filled in from a READY
            message if it is still ``None``.
        core_engines: One ``CoreEngine`` per expected engine. Their
            ``state`` attribute is updated in place.
        parallel_config: Supplies the local engine count, the DP load
            balancing mode flags, and the DP fields forwarded to engines.
        cache_config: ``num_gpu_blocks`` is incremented by the value each
            engine reports.
        proc_manager: Manager of the local engine processes, if any. Their
            sentinels are polled so that an early exit is detected.
        coord_process: DP coordinator process, if any. Also polled.

    Raises:
        RuntimeError: A monitored process exits during startup, a message
            arrives from an unknown engine, an engine's local/remote or
            headless setting does not match expectations, or a message
            arrives out of order for the engine's current state.
    """
    # Wait for engine core process(es) to send ready messages.
    local_count = parallel_config.data_parallel_size_local
    remote_count = len(core_engines) - local_count
    # Pending counters are kept as [local, remote] pairs.
    conn_pending, start_pending = [local_count, remote_count], [0, 0]

    # Build the identity -> engine map once, instead of rescanning
    # core_engines for every message. setdefault keeps the first engine
    # for a given identity, which matches a linear first-match search.
    engines_by_identity: dict[bytes, CoreEngine] = {}
    for core_engine in core_engines:
        engines_by_identity.setdefault(core_engine.identity, core_engine)

    # Remote engines must be headless iff neither hybrid nor external DP
    # load balancing mode is enabled.
    remote_should_be_headless = (
        not parallel_config.data_parallel_hybrid_lb
        and not parallel_config.data_parallel_external_lb
    )

    poller = zmq.Poller()
    poller.register(handshake_socket, zmq.POLLIN)
    # Also watch process sentinels. A process that dies during startup then
    # wakes the poll loop, instead of the loop waiting forever for a
    # handshake that will never arrive.
    if proc_manager is not None:
        for sentinel in proc_manager.sentinels():
            poller.register(sentinel, zmq.POLLIN)
    if coord_process is not None:
        poller.register(coord_process.sentinel, zmq.POLLIN)

    while any(conn_pending) or any(start_pending):
        events = poller.poll(STARTUP_POLL_PERIOD_MS)
        if not events:
            # Poll timed out: report progress and keep waiting.
            if any(conn_pending):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) to connect.",
                    *conn_pending,
                )
            if any(start_pending):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) to start.",
                    *start_pending,
                )
            continue
        if len(events) > 1 or events[0][0] != handshake_socket:
            # A process sentinel fired: a local engine or the coordinator
            # exited during startup.
            finished = proc_manager.finished_procs() if proc_manager else {}
            if coord_process is not None and coord_process.exitcode is not None:
                finished[coord_process.name] = coord_process.exitcode
            raise RuntimeError(
                "Engine core initialization failed. "
                "See root cause above. "
                f"Failed core proc(s): {finished}"
            )

        # Receive HELLO and READY messages from the input socket.
        eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
        eng_index = int.from_bytes(eng_identity, "little")
        engine = engines_by_identity.get(eng_identity)
        if engine is None:
            raise RuntimeError(
                f"Message from engine with unexpected data parallel rank: {eng_index}"
            )
        msg = msgspec.msgpack.decode(ready_msg_bytes)
        status, local, headless = msg["status"], msg["local"], msg["headless"]
        if local != engine.local:
            raise RuntimeError(
                f"{status} message from "
                f"{'local' if local else 'remote'} "
                f"engine {eng_index}, expected it to be "
                f"{'local' if engine.local else 'remote'}"
            )

        # Validate the remote engine's --headless setting against the LB mode.
        if not local and headless != remote_should_be_headless:
            if headless:
                raise RuntimeError(
                    f"Remote engine {eng_index} must not use "
                    "--headless in external or hybrid dp lb "
                    "mode"
                )
            raise RuntimeError(
                f"Remote engine {eng_index} must use "
                "--headless unless in external or hybrid "
                "dp lb mode"
            )

        # Index into the [local, remote] pending counters.
        slot = 0 if local else 1
        if status == "HELLO" and engine.state == CoreEngineState.NEW:
            # Send init message with DP config info.
            init_message = msgspec.msgpack.encode(
                EngineHandshakeMetadata(
                    addresses=addresses,
                    parallel_config={
                        k: getattr(parallel_config, k)
                        for k in (
                            "data_parallel_master_ip",
                            "data_parallel_master_port",
                            "_data_parallel_master_port_list",
                            "data_parallel_size",
                        )
                    },
                )
            )
            handshake_socket.send_multipart((eng_identity, init_message), copy=False)
            conn_pending[slot] -= 1
            start_pending[slot] += 1
            engine.state = CoreEngineState.CONNECTED
        elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
            # Setup KV cache config with initialization state from
            # engine core process. Sum values from all engines in DP case.
            num_gpu_blocks = cache_config.num_gpu_blocks or 0
            num_gpu_blocks += msg["num_gpu_blocks"]
            cache_config.num_gpu_blocks = num_gpu_blocks

            # In external DP LB mode, the coordinator address that the
            # front-end procs connect to is obtained from rank 0 via
            # one of the engine handshakes, and passed to the local
            # front-end process in the response from the other.
            if addresses.frontend_stats_publish_address is None:
                addresses.frontend_stats_publish_address = msg.get("dp_stats_address")

            start_pending[slot] -= 1
            engine.state = CoreEngineState.READY
        else:
            raise RuntimeError(
                f"Unexpected {status} message for "
                f"{'local' if local else 'remote'} engine "
                f"{eng_index} in {engine.state} state."
            )

        logger.debug(
            "%s from %s core engine process %s.",
            status,
            "local" if local else "remote",
            eng_index,
        )