parallel.py 35.9 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import os
5
from collections.abc import Callable
6
from typing import TYPE_CHECKING, Any, Literal, overload
7
8

import torch
9
from pydantic import Field, field_validator, model_validator
10
from torch.distributed import ProcessGroup, ReduceOp, Store
11
12
13
14
15
from typing_extensions import Self

import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
16
from vllm.model_executor.layers.batch_invariant import (
17
    vllm_is_batch_invariant,
18
)
19
from vllm.platforms import current_platform
20
from vllm.utils.network_utils import get_open_ports_list
21
from vllm.utils.torch_utils import cuda_device_count_stateless
22
23
24
25
26

if TYPE_CHECKING:
    from ray.runtime_env import RuntimeEnv
    from ray.util.placement_group import PlacementGroup

27
    from vllm.v1.executor import Executor
28
29
30
else:
    RuntimeEnv = Any
    PlacementGroup = Any
31
    Executor = Any
32
33
34

# Module-level logger; used by this module for config-time messages such as
# the "pplx" fallback warning, the EADDRINUSE retry warning and the
# external-launcher info message.
logger = init_logger(__name__)

35
# ---------------------------------------------------------------------------
# Literal aliases for the string-valued options of the parallel configs.
# ---------------------------------------------------------------------------

# Executor / data-parallel launch backends.
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
DataParallelBackend = Literal["ray", "mp"]

# Mixture-of-experts placement and load balancing.
ExpertPlacementStrategy = Literal["linear", "round_robin"]
EPLBPolicyOption = Literal["default"]

# Decode context parallel communication backend.
DCPCommBackend = Literal["ag_rs", "a2a"]

# MoE all2all backends. "pplx" is still accepted by the type so that old
# configs parse; ParallelConfig's validator rewrites it to
# "allgather_reducescatter" with a warning.
All2AllBackend = Literal[
    "naive", "pplx", "deepep_high_throughput", "deepep_low_latency",
    "mori", "allgather_reducescatter", "flashinfer_all2allv",
]
49
50


51
52
53
54
55
56
57
58
59
60
61
62
63
64
@config
class EPLBConfig:
    """Configuration for Expert Parallel Load Balancing (EPLB)."""

    window_size: int = 1000
    """Window size for expert load recording."""
    step_interval: int = 3000
    """
    Interval for rearranging experts in expert parallelism.

    Note that if this is greater than the EPLB window size, only the metrics
    of the last `window_size` steps will be used for rearranging experts.
    """

    num_redundant_experts: int = Field(default=0, ge=0)
    """Number of redundant experts to use for expert parallelism."""

    log_balancedness: bool = False
    """
    Log the balancedness of expert parallelism every
    `log_balancedness_interval` steps.
    This is turned off by default since it will cause communication overhead.
    """
    log_balancedness_interval: int = 1
    """
    Interval for logging the balancedness. Must be greater than 0 when
    `log_balancedness` is enabled.
    """
    use_async: bool = False
    """
    Whether to use non-blocking EPLB. Only supported with the "default" policy.
    """

    policy: EPLBPolicyOption = "default"
    """The policy type for expert parallel load balancing (EPLB)."""

    @model_validator(mode="after")
    def _validate_eplb_config(self) -> Self:
        """Reject EPLB option combinations that are not supported.

        Raises:
            ValueError: if async EPLB is combined with a non-default policy,
                or if balancedness logging is enabled with a non-positive
                interval.
        """
        if self.use_async and self.policy != "default":
            raise ValueError("Async EPLB is only supported with the default policy.")
        # The interval only matters when logging is actually enabled.
        if self.log_balancedness and self.log_balancedness_interval <= 0:
            raise ValueError("log_balancedness_interval must be greater than 0.")
        return self

93

94
95
96
97
98
99
100
101
@config
class ParallelConfig:
    """Configuration for the distributed execution."""

    pipeline_parallel_size: int = 1
    """Number of pipeline parallel groups."""
    tensor_parallel_size: int = 1
    """Number of tensor parallel groups."""
102
103
    prefill_context_parallel_size: int = 1
    """Number of prefill context parallel groups."""
104
105
106
107
108
109
110
    data_parallel_size: int = 1
    """Number of data parallel groups. MoE layers will be sharded according to
    the product of the tensor parallel size and data parallel size."""
    data_parallel_size_local: int = 1
    """Number of local data parallel groups."""
    data_parallel_rank: int = 0
    """Rank of the data parallel group."""
111
    data_parallel_rank_local: int | None = None
112
113
114
115
116
117
118
119
    """Local rank of the data parallel group,
    set only in SPMD mode."""
    data_parallel_master_ip: str = "127.0.0.1"
    """IP of the data parallel master."""
    data_parallel_rpc_port: int = 29550
    """Port for data parallel messaging."""
    data_parallel_master_port: int = 29500
    """Port of the data parallel master."""
120
    data_parallel_backend: DataParallelBackend = "mp"
121
122
123
124
    """Backend to use for data parallel, either "mp" or "ray"."""
    data_parallel_external_lb: bool = False
    """Whether to use "external" DP LB mode. Applies only to online serving
    and when data_parallel_size > 1. This is useful for a "one-pod-per-rank"
    wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank
    is provided explicitly to vllm serve."""
    data_parallel_hybrid_lb: bool = False
    """Whether to use "hybrid" DP LB mode. Applies only to online serving
    and when data_parallel_size > 0. Enables running an AsyncLLM
    and API server on a "per-node" basis where vLLM load balances
    between local data parallel ranks, but an external LB balances
    between vLLM nodes/replicas. Set explicitly in conjunction with
    --data-parallel-start-rank."""
134
135
    is_moe_model: bool | None = None
    """Whether the deployed model is MoE (if known)."""
136
137
138
139
    enable_expert_parallel: bool = False
    """Use expert parallelism instead of tensor parallelism for MoE layers."""
    enable_eplb: bool = False
    """Enable expert parallelism load balancing for MoE layers."""
140
    eplb_config: EPLBConfig = Field(default_factory=EPLBConfig)
141
    """Expert parallelism configuration."""
142
143
144
145
146
147
148
149
150
    expert_placement_strategy: ExpertPlacementStrategy = "linear"
    """The expert placement strategy for MoE layers:\n
    - "linear": Experts are placed in a contiguous manner. For example, with 4
      experts and 2 ranks, rank 0 will have experts [0, 1] and rank 1 will have
      experts [2, 3].\n
    - "round_robin": Experts are placed in a round-robin manner. For example,
      with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
      will have experts [1, 3]. This strategy can help improve load balancing
      for grouped expert models with no redundant experts."""
151
152
153
154
155
156
157
    all2all_backend: All2AllBackend = "allgather_reducescatter"
    """All2All backend for MoE expert parallel communication. Available options:

    - "naive": Naive all2all implementation using broadcasts\n
    - "allgather_reducescatter": All2all based on allgather and reducescatter\n
    - "deepep_high_throughput": Use deepep high-throughput kernels\n
    - "deepep_low_latency": Use deepep low-latency kernels\n
    - "mori": Use mori kernels\n
    - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl\n
    - "pplx": Removed; falls back to "allgather_reducescatter" with a warning"""
160

161
    max_parallel_loading_workers: int | None = None
162
163
164
165
166
167
168
    """Maximum number of parallel loading workers when loading model
    sequentially in multiple batches. To avoid RAM OOM when using tensor
    parallel and large models."""

    disable_custom_all_reduce: bool = False
    """Disable the custom all-reduce kernel and fall back to NCCL."""

169
170
171
    enable_elastic_ep: bool = False
    """Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""

172
    enable_dbo: bool = False
173
    """Enable dual batch overlap for the model executor."""
174
175
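    ubatch_size: int = 0
    """Number of micro-batches (ubatches) to split a batch into. Micro-batching
    is used when this is greater than 1 (or when `enable_dbo` is set)."""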
176
177

    dbo_decode_token_threshold: int = 32
178
179
180
181
182
183
184
185
186
    """The threshold for dual batch overlap for batches only containing decodes.
    If the number of tokens in the request is greater than this threshold,
    microbatching will be used. Otherwise, the request will be processed in a
    single batch."""
    dbo_prefill_token_threshold: int = 512  # TODO(lucas): tune
    """The threshold for dual batch overlap for batches that contain one or more
    prefills. If the number of tokens in the request is greater than this
    threshold, microbatching will be used. Otherwise, the request will be
    processed in a single batch."""
187

188
    disable_nccl_for_dp_synchronization: bool | None = Field(default=None)
    """Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py
    to use Gloo instead of NCCL for its all reduce.

    Defaults to True when async scheduling is enabled, False otherwise.
    """
194

195
196
197
    ray_workers_use_nsight: bool = False
    """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

198
    ray_runtime_env: RuntimeEnv | None = None
199
200
    """Ray runtime environment to pass to distributed workers."""

201
    placement_group: PlacementGroup | None = None
    """Ray placement group for the distributed model workers."""

204
    distributed_executor_backend: (
205
        str | DistributedExecutorBackend | type[Executor] | None
206
    ) = None
207
208
209
210
211
212
213
214
    """Backend to use for distributed model workers, either "ray" or "mp"
    (multiprocessing). If the product of pipeline_parallel_size and tensor_parallel_size
    is less than or equal to the number of GPUs available, "mp" will be used to
    keep processing on a single host. Otherwise, an error will be raised. To use "mp"
    across multiple nodes you must also set nnodes, and to use "ray" you must
    manually set distributed_executor_backend to "ray".

    Note that TPU only supports Ray for distributed inference."""
215
216
217
218
219
220
221
222
223
224
225
226

    worker_cls: str = "auto"
    """The full name of the worker class to use. If "auto", the worker class
    will be determined based on the platform."""
    sd_worker_cls: str = "auto"
    """The full name of the worker class to use for speculative decoding.
    If "auto", the worker class will be determined based on the platform."""
    worker_extension_cls: str = ""
    """The full name of the worker extension class to use. The worker extension
    class is dynamically inherited by the worker class. This is used to inject
    new attributes and methods to the worker class for use in collective_rpc
    calls."""
227
228
229
230
231
232
233
234
235
236
    master_addr: str = "127.0.0.1"
    """distributed master address for multi-node distributed 
    inference when distributed_executor_backend is mp."""
    master_port: int = 29501
    """distributed master port for multi-node distributed 
    inference when distributed_executor_backend is mp."""
    node_rank: int = 0
    """distributed node rank for multi-node distributed 
    inference when distributed_executor_backend is mp."""
    nnodes: int = 1
    """Number of nodes for multi-node distributed
    inference when distributed_executor_backend is mp."""
239

240
241
242
243
244
245
    distributed_timeout_seconds: int | None = None
    """Timeout in seconds for distributed operations (e.g., init_process_group).
    If set, this value is passed to torch.distributed.init_process_group as the
    timeout parameter. If None, PyTorch's default timeout is used (600s for NCCL).
    Increase this for multi-node setups where model downloads may be slow."""

246
    world_size: int = Field(init=False)
    """world_size is PPxTPxPCP (additionally multiplied by DP when using the
    "external_launcher" backend); it affects the number of workers we create."""

    rank: int = 0
    """Global rank in distributed setup."""

252
    _data_parallel_master_port_list: list[int] = Field(default_factory=list)
    """List of open ports auto-queried for data parallel messaging.
    Set to be private as it's not intended to be configured by users.
    """

257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
    _stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list)
    """List of open ports for stateless DP groups when enable_elastic_ep is True.
    Set to be private as it's not intended to be configured by users.
    It is a list of list[int], with each inner list contains a set of 3 ports
    to be used for setting up the stateless CPU/device/TCPStore groups
    in StatelessGroupCoordinator. The number of inner lists is equal to
    the number of DP groups, 
    i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
    and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
    """

    _stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
    """List of open ports for stateless EP groups when enable_elastic_ep is True.
    Set to be private as it's not intended to be configured by users.
    len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size,
    """

    _stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list)
    """List of open ports for stateless EPLB groups when enable_elastic_ep is True.
    Same topology as EP but separate NCCL communicator to avoid deadlocks.
    """

    _stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
    """List of open ports for stateless world group when enable_elastic_ep is True.
    Set to be private as it's not intended to be configured by users.
    len(self._stateless_world_group_port_list) == 1,
    """

285
286
287
288
289
    decode_context_parallel_size: int = 1
    """Number of decode context parallel groups. Because DCP does not change
    the world size, it simply reuses the GPUs of the TP group, so tp_size
    needs to be divisible by dcp_size."""

290
    dcp_kv_cache_interleave_size: int = 1
291
292
293
294
295
296
    """
    Interleave size of kv_cache storage while using DCP.
    dcp_kv_cache_interleave_size has been replaced by cp_kv_cache_interleave_size,
    and will be deprecated when PCP is fully supported.

    """
297
298
299
300
301
302
303
304
    dcp_comm_backend: DCPCommBackend = "ag_rs"
    """Communication backend for Decode Context Parallel (DCP).
    - "ag_rs": AllGather + ReduceScatter (default, existing behavior)
    - "a2a": All-to-All exchange of partial outputs + LSE, then
      combine with Triton kernel. Reduces NCCL calls from 3 to 2
      per layer for MLA models.
    """

305
306
307
    cp_kv_cache_interleave_size: int = 1
    """Interleave size of kv_cache storage while using DCP or PCP.
    With `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`
        and `total_cp_world_size = pcp_world_size * dcp_world_size`,
    interleave_size tokens are stored on total_cp_rank i,
    then the next interleave_size tokens are stored on total_cp_rank i+1.
    Interleave_size=1: token-level alignment, where token `i` is stored on
        total_cp_rank `i % total_cp_world_size`.
    Interleave_size=block_size: block-level alignment, where tokens are
        first populated to the preceding ranks. Tokens are then stored
        in (rank i+1, block j) only after (rank i, block j) is fully occupied.
    Block_size should be greater than or equal to cp_kv_cache_interleave_size.
    Block_size should be divisible by cp_kv_cache_interleave_size.
    """

320
321
322
323
    data_parallel_index: int = Field(init=False)
    """Equal to the data parallel rank but not used for torch process groups
    and not overridden for dense models."""

324
    _api_process_count: int = Field(default=1, gt=0)
325
326
327
328
329
330
331
332
    """
    The number of API processes initialized.

    Note:
        This is an internal config that is only valid for and
        should only be set by API server scale-out.
    """

333
    _api_process_rank: int = Field(default=0, ge=-1)
334
335
336
337
338
339
340
341
342
    """
    The rank of this API process, or `-1` for engine core processes
    under API server scale-out.

    Note:
        This is an internal config that is only valid for and
        should only be set by API server scale-out.
    """

343
344
345
346
347
348
    @field_validator("disable_nccl_for_dp_synchronization", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        """Let `None` through unvalidated so initialisation can be delayed.

        `None` acts as a "decide later" placeholder for
        `disable_nccl_for_dp_synchronization`; every other value is handed to
        pydantic's regular validator through `handler`.
        """
        if value is not None:
            return handler(value)
        return None

349
350
351
352
353
354
355
356
357
    @model_validator(mode="after")
    def _validate_parallel_config(self) -> Self:
        """Check cross-field constraints once all fields have been parsed.

        Side effect: the removed ``"pplx"`` all2all backend is rewritten to
        ``"allgather_reducescatter"`` (with a warning).

        Raises:
            ValueError: if `_api_process_rank` is out of range, if
                `data_parallel_size_local` exceeds `data_parallel_size`, if
                external DP LB is requested without DP > 1, if the EPLB
                settings are inconsistent (unsupported platform, expert
                parallelism disabled, TP*DP <= 1, or redundant experts
                without EPLB), or if the decode context parallel settings
                are invalid.
        """
        if self._api_process_rank >= self._api_process_count:
            raise ValueError(
                "Invalid value of `_api_process_rank`. "
                f"Expected to be `-1` or `[0, {self._api_process_count})`, "
                f"but found: {self._api_process_rank}"
            )

        if self.all2all_backend == "pplx":
            logger.warning(
                "The 'pplx' all2all backend has been removed. "
                "Falling back to 'allgather_reducescatter'."
            )
            self.all2all_backend = "allgather_reducescatter"

        if self.data_parallel_size_local > self.data_parallel_size:
            raise ValueError(
                f"data_parallel_size_local ({self.data_parallel_size_local}) "
                f"must be <= data_parallel_size ({self.data_parallel_size})"
            )

        if self.data_parallel_size <= 1 and self.data_parallel_external_lb:
            raise ValueError(
                "data_parallel_external_lb can only be set when data_parallel_size > 1"
            )

        if self.enable_eplb:
            if not current_platform.is_cuda_alike():
                raise ValueError(
                    "Expert parallelism load balancing is only supported on "
                    "CUDA devices or ROCm devices now."
                )
            if not self.enable_expert_parallel:
                raise ValueError("enable_expert_parallel must be True to use EPLB.")
            if self.tensor_parallel_size * self.data_parallel_size <= 1:
                raise ValueError(
                    "EPLB requires tensor_parallel_size or data_parallel_size "
                    "to be greater than 1, but got "
                    f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
                )
        elif self.eplb_config.num_redundant_experts != 0:
            raise ValueError(
                "num_redundant_experts is set to "
                f"{self.eplb_config.num_redundant_experts} but EPLB is not "
                "enabled. Either enable EPLB or unset "
                "num_redundant_experts."
            )

        # A non-positive DCP size is meaningless. Checking it first also
        # protects the modulo below: 0 would raise ZeroDivisionError and a
        # negative divisor could silently pass the divisibility test.
        if self.decode_context_parallel_size < 1:
            raise ValueError(
                "decode_context_parallel_size must be >= 1, but got "
                f"{self.decode_context_parallel_size}."
            )

        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
        # reuses the GPUs of TP group, and split one TP group into
        # tp_size//dcp_size DCP groups.
        if self.tensor_parallel_size % self.decode_context_parallel_size != 0:
            raise ValueError(
                f"tp_size={self.tensor_parallel_size} must be divisible by "
                f"dcp_size={self.decode_context_parallel_size}."
            )

        if self.dcp_comm_backend == "a2a" and self.decode_context_parallel_size <= 1:
            raise ValueError(
                "dcp_comm_backend='a2a' requires decode_context_parallel_size > 1."
            )

        return self

417
418
419
420
421
422
    @property
    def world_size_across_dp(self) -> int:
        """Total number of ranks once data parallelism is included.

        Equal to ``world_size`` multiplied by ``data_parallel_size``.
        """
        dp = self.data_parallel_size
        per_replica = self.world_size
        return per_replica * dp

423
424
425
426
427
428
429
430
    @property
    def use_ubatching(self) -> bool:
        """Whether micro-batching is active: DBO is on, or ubatch_size > 1."""
        if self.enable_dbo:
            return True
        return self.ubatch_size > 1

    @property
    def num_ubatches(self) -> int:
        """Number of micro-batches: fixed at 2 under DBO, else `ubatch_size`."""
        if not self.enable_dbo:
            return self.ubatch_size
        return 2

431
432
433
434
435
436
437
438
    @property
    def local_engines_only(self) -> bool:
        """Whether the client manages only the EngineCores on its own node.

        True in the external and hybrid DP LB modes. In the pure internal LB
        case the client manages both local and remote EngineCores.
        """
        lb_flags = (self.data_parallel_external_lb, self.data_parallel_hybrid_lb)
        return any(lb_flags)

439
440
441
442
443
444
    def get_next_dp_init_port(self) -> int:
        """Hand out a fresh port for a data-parallel process group.

        Process groups tied to data parallelism may be created in several
        processes (e.g. both the worker and the engine). To avoid port
        conflicts, every call consumes one port: it pops from the pre-queried
        port list while that list is non-empty, and otherwise returns
        `data_parallel_master_port` and advances it by one for the next call.
        """
        port_pool = self._data_parallel_master_port_list
        if not port_pool:
            next_port = self.data_parallel_master_port
            self.data_parallel_master_port = next_port + 1
            return next_port
        return port_pool.pop()

456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
    def allocate_elastic_ep_ports(self) -> None:
        """Allocate all ports for elastic EP (stateless groups + DP master).

        Must be called AFTER ray.init() so that ports claimed by Ray's
        idle worker pool are already in use and won't be returned by
        get_open_ports_list().

        A single batch of open ports is queried and split as follows:

        - the last ``num_dp_master_ports`` ports become
          ``_data_parallel_master_port_list``, and one of them is popped
          immediately into ``data_parallel_master_port``;
        - the remaining ports are carved into 3-port sets (for the stateless
          CPU/device/TCPStore groups) for, in order, the world group, the DP
          groups, the EP groups and the EPLB groups.

        This is a no-op when elastic EP is disabled or when the ports have
        already been allocated.
        """
        if not self.enable_elastic_ep:
            return
        if self._stateless_world_group_port_list:
            # Already allocated; keep the existing assignment.
            return

        ports_per_group = 3
        num_dp_master_ports = 5
        world_size_across_dp = self.world_size_across_dp

        dp_size = self.data_parallel_size
        # An EP group spans DP x TP (MoE layers are sharded across the product
        # of the tensor parallel and data parallel sizes). Note that
        # world_size_across_dp already includes DP, so it must not be used to
        # build ep_size -- doing so double-counts DP and always yields one group.
        ep_size = self.data_parallel_size * self.tensor_parallel_size

        num_world_groups = 1
        num_dp_groups = max(1, world_size_across_dp // dp_size)
        num_ep_groups = max(1, world_size_across_dp // ep_size)
        # EPLB uses the same topology as EP, with separate communicators.
        num_eplb_groups = num_ep_groups

        total_stateless_ports = (
            num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
        ) * ports_per_group
        all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports)

        self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:]
        self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
        stateless_ports = all_ports[:-num_dp_master_ports]

        def _take_groups(start: int, num_groups: int) -> list[list[int]]:
            """Slice `num_groups` consecutive 3-port sets beginning at `start`."""
            end = start + num_groups * ports_per_group
            return [
                stateless_ports[i : i + ports_per_group]
                for i in range(start, end, ports_per_group)
            ]

        offset = 0
        self._stateless_world_group_port_list = _take_groups(offset, num_world_groups)
        offset += num_world_groups * ports_per_group
        self._stateless_dp_group_port_list = _take_groups(offset, num_dp_groups)
        offset += num_dp_groups * ports_per_group
        self._stateless_ep_group_port_list = _take_groups(offset, num_ep_groups)
        offset += num_ep_groups * ports_per_group
        self._stateless_eplb_group_port_list = _take_groups(offset, num_eplb_groups)

    @staticmethod
    def _pop_stateless_group_ports(
        port_list: list[list[int]], group_name: str
    ) -> list[int]:
        """Pop the next 3-port set for a stateless group, failing clearly if empty.

        Raises:
            IndexError: if no port set is left (same exception type as a bare
                ``list.pop()``, but with an actionable message).
        """
        if not port_list:
            raise IndexError(
                f"No ports left for the stateless {group_name} group. These "
                "ports are only allocated by allocate_elastic_ep_ports() when "
                "enable_elastic_ep is True; make sure it was called and that "
                "each port set is consumed only once."
            )
        return port_list.pop()

    def get_next_stateless_world_group_port(self) -> list[int]:
        """Pop the next port set reserved for the stateless world group."""
        return self._pop_stateless_group_ports(
            self._stateless_world_group_port_list, "world"
        )

    def get_next_stateless_dp_group_port(self) -> list[int]:
        """Pop the next port set reserved for a stateless DP group."""
        return self._pop_stateless_group_ports(
            self._stateless_dp_group_port_list, "DP"
        )

    def get_next_stateless_ep_group_port(self) -> list[int]:
        """Pop the next port set reserved for a stateless EP group."""
        return self._pop_stateless_group_ports(
            self._stateless_ep_group_port_list, "EP"
        )

    def get_next_stateless_eplb_group_port(self) -> list[int]:
        """Pop the next port set reserved for a stateless EPLB group."""
        return self._pop_stateless_group_ports(
            self._stateless_eplb_group_port_list, "EPLB"
        )

516
517
518
519
520
521
522
523
524
525
526
    @overload
    def stateless_init_dp_group(
        self, return_store: Literal[False] = ...
    ) -> ProcessGroup: ...
    @overload
    def stateless_init_dp_group(
        self, return_store: Literal[True]
    ) -> tuple[ProcessGroup, Store]: ...
    def stateless_init_dp_group(
        self, return_store: bool = False
    ) -> ProcessGroup | tuple[ProcessGroup, Store]:
        """Create a stateless gloo process group over the data parallel ranks.

        The group is rooted at ``data_parallel_master_ip`` on a port obtained
        from ``get_next_dp_init_port()``, with this process at
        ``data_parallel_rank`` out of ``data_parallel_size`` ranks. Each retry
        uses a new port.

        Args:
            return_store: if True, also return the Store backing the group.

        Returns:
            The process group, or ``(process_group, store)`` when
            ``return_store`` is True.

        Raises:
            DistNetworkError: immediately for any network error other than
                EADDRINUSE, or the last EADDRINUSE error once all retries
                have been used up.
        """
        # NOTE: In high-concurrency scenarios multiple processes
        # can pick the same (currently free) port through a race
        # condition when calling `get_open_port()`. When the first
        # process binds the port the others will subsequently fail
        # with `torch.distributed.DistNetworkError: EADDRINUSE`.
        # To make the initialization more robust we retry a few times
        # with a fresh port whenever this specific error is observed.
        from torch.distributed import DistNetworkError

        from vllm.distributed.utils import (
            stateless_init_torch_distributed_process_group,
        )

        max_retries = 5
        last_exc: Exception | None = None
        for _ in range(max_retries):
            try:
                # use gloo since the engine process might not have cuda device
                return stateless_init_torch_distributed_process_group(
                    self.data_parallel_master_ip,
                    self.get_next_dp_init_port(),
                    self.data_parallel_rank,
                    self.data_parallel_size,
                    backend="gloo",
                    return_store=return_store,
                )
            except DistNetworkError as e:
                # We only want to retry when the root cause is EADDRINUSE.
                if "EADDRINUSE" not in str(e):
                    raise
                logger.warning("Address already in use. Retrying with a new port.")
                last_exc = e

        # If we get here all retries have failed.
        assert last_exc is not None
        raise last_exc

565
566
567
568
569
570
571
572
573
574
    @property
    def use_sequence_parallel_moe(self) -> bool:
        """Whether the input to the MoE experts should be sequence parallel.

        The all_reduce at the end of attention (during o_proj) leaves the
        inputs replicated across every rank of the tensor parallel group.
        With expert parallelism over one of the all2all backends listed
        below, those replicated tokens would cause useless duplicate
        computation and communication, so the expert inputs are kept
        sequence parallel instead. Requires expert parallelism with both
        TP > 1 and DP > 1.
        """
        if not self.enable_expert_parallel:
            return False
        if self.tensor_parallel_size <= 1 or self.data_parallel_size <= 1:
            return False
        sequence_parallel_backends = (
            "allgather_reducescatter",
            "naive",
            "deepep_high_throughput",
            "deepep_low_latency",
            "mori",
        )
        return self.all2all_backend in sequence_parallel_backends
588

589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
    @property
    def node_rank_within_dp(self) -> int:
        """This node's rank among the `nnodes_within_dp` nodes of its DP group."""
        nodes_per_dp = self.nnodes_within_dp
        return self.node_rank % nodes_per_dp

    @property
    def nnodes_within_dp(self) -> int:
        """Number of nodes per data parallel group.

        Always 1 for single-node runs; otherwise `nnodes` divided by
        `data_parallel_size // data_parallel_size_local`.
        """
        if self.nnodes != 1:
            # NOTE(review): raises ZeroDivisionError when
            # data_parallel_size_local is 0 (a value __post_init__ handles).
            dp_node_count = self.data_parallel_size // self.data_parallel_size_local
            return self.nnodes // dp_node_count
        return 1

    @property
    def local_world_size(self) -> int:
        """Per-node share of `world_size`: world_size // nnodes_within_dp."""
        nodes_per_dp = self.nnodes_within_dp
        return self.world_size // nodes_per_dp

606
    @staticmethod
    def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
        """Return whether any rank in `dp_group` still has unfinished work.

        Each rank contributes its local flag as an int32 CPU tensor. A MAX
        all-reduce over 0/1 values is a logical OR: for example, ranks
        reporting (True, False) aggregate to True.

        Args:
            dp_group: process group to reduce over.
            has_unfinished: this rank's local flag.
        """
        flag = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
        torch.distributed.all_reduce(flag, op=ReduceOp.MAX, group=dp_group)
        return bool(flag.item())

    @staticmethod
    def sync_kv_cache_memory_size(dp_group: ProcessGroup, kv_cache_memory: int) -> int:
        """Return the smallest KV-cache memory size reported across `dp_group`.

        A value of -1 is mapped to the int64 maximum so that it never wins the
        MIN reduction (presumably meaning "unconstrained" -- confirm with
        callers). If every rank passes -1, the int64 maximum is returned.
        """
        if kv_cache_memory != -1:
            local_value = kv_cache_memory
        else:
            local_value = torch.iinfo(torch.int64).max
        buf = torch.tensor([local_value], dtype=torch.int64, device="cpu")
        # A broadcast cannot be used on a stateless DP group because it
        # depends on the global rank, so agree via an all_reduce(MIN).
        torch.distributed.all_reduce(buf, op=ReduceOp.MIN, group=dp_group)
        return buf.item()

    def compute_hash(self):
        """
        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        This hash is also used for DP worker configuration validation
        to prevent hangs from mismatched collective communication patterns.

        Derived/runtime topology, networking and launch details (ranks,
        addresses, auto-allocated ports, worker classes, ...) are excluded so
        that they do not perturb the hash.
        """
        ignored_factors = {
            # Derived/runtime topology, networking, or launch details
            "data_parallel_rank",
            "data_parallel_rank_local",
            "data_parallel_size_local",
            "data_parallel_index",
            "data_parallel_backend",
            "data_parallel_external_lb",
            "data_parallel_hybrid_lb",
            "data_parallel_master_ip",
            "data_parallel_master_port",
            "_data_parallel_master_port_list",
            # Ports auto-allocated at runtime for the elastic-EP stateless
            # groups; like `_data_parallel_master_port_list` they are
            # networking details and differ between runs and processes.
            "_stateless_dp_group_port_list",
            "_stateless_ep_group_port_list",
            "_stateless_eplb_group_port_list",
            "_stateless_world_group_port_list",
            "data_parallel_rpc_port",
            "rank",
            "master_addr",
            "master_port",
            "node_rank",
            "nnodes",
            "max_parallel_loading_workers",
            "disable_custom_all_reduce",
            "ray_workers_use_nsight",
            "ray_runtime_env",
            "placement_group",
            "distributed_executor_backend",
            "worker_cls",
            "sd_worker_cls",
            "worker_extension_cls",
            "_api_process_count",
            "_api_process_rank",
        }

        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, ignored_factors)
        return hash_factors(factors)
673
674

    def __post_init__(self) -> None:
        """Derive dependent parallelism settings and fill in defaults.

        In order, this method:

        * computes ``world_size`` from the pipeline, tensor and
          prefill-context parallel sizes (times ``data_parallel_size`` for
          the external launcher);
        * validates the Elastic EP constraints;
        * resolves the data-parallel rank and master port, either from the
          engine args or, if DP was not configured there, from the
          ``VLLM_DP_*`` environment variables;
        * picks a default ``distributed_executor_backend`` when none was set;
        * validates multi-node settings and forces synchronous EPLB for
          all2all backends where async EPLB is known to hang.

        Raises:
            ValueError: On inconsistent Elastic EP, data-parallel, GPU-count
                or multi-node settings.
            NotImplementedError: If Elastic EP is combined with external or
                hybrid data-parallel load balancing.
            KeyError: If the external launcher is used with data parallelism
                and the ``RANK`` environment variable is not set.
        """
        self.world_size = (
            self.pipeline_parallel_size
            * self.tensor_parallel_size
            * self.prefill_context_parallel_size
        )

        if self.distributed_executor_backend == "external_launcher":
            logger.info("Using external launcher for distributed inference.")
            # With an external launcher, world_size spans every DP replica;
            # the DP rank is derived from the global RANK further below.
            self.world_size *= self.data_parallel_size

        if self.enable_elastic_ep:
            if not self.enable_eplb:
                raise ValueError("Elastic EP is only supported with enable_eplb=True.")
            if self.pipeline_parallel_size > 1:
                raise ValueError(
                    "Elastic EP is not supported with pipeline parallelism "
                    f"(pipeline_parallel_size={self.pipeline_parallel_size})."
                )
            if self.data_parallel_external_lb or self.data_parallel_hybrid_lb:
                raise NotImplementedError(
                    "Elastic EP is not compatible with data_parallel_external_lb "
                    "or data_parallel_hybrid_lb. Elastic EP relies on a single API "
                    "server and core client to coordinate scale up/down."
                )

        if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
            # Data parallel was specified in the engine args.
            if self.distributed_executor_backend == "external_launcher":
                # For the external launcher the DP rank is not passed in;
                # each DP replica owns a contiguous block of
                # world_size // data_parallel_size global ranks.
                self.data_parallel_rank = int(os.environ["RANK"]) // (
                    self.world_size // self.data_parallel_size
                )
                logger.info(
                    "Set data_parallel_rank to %d automatically.",
                    self.data_parallel_rank,
                )

            if not self.enable_elastic_ep:
                # Reserve a pool of open ports once (reused if already
                # populated) and take one as the DP master port.
                if not self._data_parallel_master_port_list:
                    self._data_parallel_master_port_list = get_open_ports_list(5)
                self.data_parallel_master_port = (
                    self._data_parallel_master_port_list.pop()
                )

            if not (0 <= self.data_parallel_rank < self.data_parallel_size):
                raise ValueError(
                    f"data_parallel_rank ({self.data_parallel_rank})"
                    f" must be in the range [0, {self.data_parallel_size})"
                )
        else:
            # Otherwise fall back to env vars (e.g. for offline SPMD case).
            self.data_parallel_size = envs.VLLM_DP_SIZE
            self.data_parallel_rank = envs.VLLM_DP_RANK
            self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
            self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
            self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT

            # `is False` (not `not`) so that an unknown (None) MoE status
            # does not trigger this check.
            if self.data_parallel_size > 1 and self.is_moe_model is False:
                raise ValueError(
                    "Offline data parallel mode is not supported/useful"
                    " for dense models."
                )

        self.data_parallel_index = self.data_parallel_rank

        if self.distributed_executor_backend == "external_launcher":
            os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
            logger.info("Disabling V1 multiprocessing for external launcher.")

        if self.distributed_executor_backend is None and self.world_size_across_dp > 1:
            # We use multiprocessing by default if world_size fits on the
            # current node and we aren't in a ray placement group.
            from vllm.v1.executor import ray_utils

            backend: DistributedExecutorBackend = "mp"
            ray_found = ray_utils.ray_is_available()
            if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
                backend = "uni"
            elif current_platform.is_cuda() and self.nnodes > 1:
                backend = "mp"
            elif (
                current_platform.is_cuda()
                and (gpu_count := cuda_device_count_stateless()) < self.world_size
            ):
                # Single-node CUDA run that needs more GPUs than are visible.
                raise ValueError(
                    f"World size ({self.world_size}) is larger than the number of "
                    f"available GPUs ({gpu_count}) in this node. If this is "
                    "intentional and you are using:\n"
                    "- ray, set '--distributed-executor-backend ray'.\n"
                    "- multiprocessing, set '--nnodes' appropriately."
                )
            elif self.data_parallel_backend == "ray":
                logger.info(
                    "Using ray distributed inference because "
                    "data_parallel_backend is ray"
                )
                backend = "ray"
            elif ray_found:
                if self.placement_group:
                    backend = "ray"
                else:
                    from ray import is_initialized as ray_is_initialized

                    if ray_is_initialized():
                        from ray.util import get_current_placement_group

                        # Already running inside a Ray placement group.
                        if get_current_placement_group():
                            backend = "ray"
            self.distributed_executor_backend = backend
            logger.debug("Defaulting to use %s for distributed inference", backend)

        if self.distributed_executor_backend is None and self.world_size == 1:
            self.distributed_executor_backend = "uni"

        if self.max_parallel_loading_workers is not None:
            logger.warning(
                "max_parallel_loading_workers is currently "
                "not supported and will be ignored."
            )

        allowed_backends = ("mp", "uni", "external_launcher")
        if (
            self.distributed_executor_backend not in allowed_backends
            and self.nnodes > 1
        ):
            raise ValueError(
                "nnodes > 1 can only be set when distributed executor "
                "backend is mp, uni or external_launcher."
            )

        if (
            self.all2all_backend in ("allgather_reducescatter", "naive")
            and self.eplb_config.use_async
        ):
            logger.warning(
                "Async EPLB causes hangs with the '%s' all2all backend. "
                "Forcing synchronous EPLB.",
                self.all2all_backend,
            )
            self.eplb_config.use_async = False

    @property
    def use_ray(self) -> bool:
        """Whether the configured executor backend runs its workers on Ray.

        True for the built-in ``"ray"`` backend, or for a custom executor
        class whose ``uses_ray`` attribute is truthy; False otherwise
        (including when no backend has been chosen yet).
        """
        backend = self.distributed_executor_backend
        if backend == "ray":
            return True
        # Custom Executor subclasses opt in via a class-level `uses_ray` flag.
        return isinstance(backend, type) and getattr(backend, "uses_ray", False)

    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Validate the executor backend and normalize custom all-reduce.

        Forces ``disable_custom_all_reduce`` when batch invariance is
        requested, when the platform does not support custom all-reduce,
        or when running on more than one node. Also checks that Ray is
        available whenever the selected backend uses it.

        Returns:
            The validated config (``self``).

        Raises:
            ValueError: If ``distributed_executor_backend`` is neither a
                string nor an ``Executor`` subclass, or if
                ``ray_workers_use_nsight`` is set without a Ray backend.
        """
        # Lazy import to avoid circular import
        from vllm.v1.executor import Executor

        # Enable batch invariance settings if requested
        if vllm_is_batch_invariant():
            self.disable_custom_all_reduce = True

        # Strings are accepted as-is (built-in names or an import path);
        # anything else must be an Executor subclass.
        if (
            self.distributed_executor_backend is not None
            and not isinstance(self.distributed_executor_backend, str)
            and not (
                isinstance(self.distributed_executor_backend, type)
                and issubclass(self.distributed_executor_backend, Executor)
            )
        ):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{self.distributed_executor_backend}. Supported "
                "values are 'ray', 'mp', 'uni', 'external_launcher', "
                "a custom Executor subclass or its import path."
            )
        if self.use_ray:
            from vllm.v1.executor import ray_utils

            ray_utils.assert_ray_available()

        if not current_platform.use_custom_allreduce():
            self.disable_custom_all_reduce = True
            logger.debug(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on current platform."
            )
        if self.nnodes > 1:
            self.disable_custom_all_reduce = True
            logger.debug(
                "Disabled the custom all-reduce since we are running on multi-node."
            )
        if self.ray_workers_use_nsight and not self.use_ray:
            raise ValueError(
                "Unable to use nsight profiling unless workers run with Ray."
            )

        return self