parallel.py 33.9 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import os
5
import socket
6
from collections.abc import Callable
7
from typing import TYPE_CHECKING, Any, Literal, overload
8
9

import torch
10
from pydantic import Field, field_validator, model_validator
11
from torch.distributed import ProcessGroup, ReduceOp, Store
12
13
14
15
16
17
from typing_extensions import Self

import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.platforms import current_platform
18
from vllm.utils.network_utils import get_open_ports_list
19
from vllm.utils.torch_utils import cuda_device_count_stateless
20
21
22
23
24

if TYPE_CHECKING:
    from ray.runtime_env import RuntimeEnv
    from ray.util.placement_group import PlacementGroup

25
    from vllm.v1.executor import Executor
26
27
28
else:
    RuntimeEnv = Any
    PlacementGroup = Any
29
    Executor = Any
30
31
32

logger = init_logger(__name__)

33
# Allowed values for `ParallelConfig.expert_placement_strategy`.
ExpertPlacementStrategy = Literal["linear", "round_robin"]
# Allowed values for `ParallelConfig.distributed_executor_backend` (a string
# form; an Executor subclass may also be passed there).
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
# Allowed values for `ParallelConfig.data_parallel_backend`.
DataParallelBackend = Literal["ray", "mp"]
# Allowed values for `EPLBConfig.policy`.
EPLBPolicyOption = Literal["default"]
# Allowed values for `ParallelConfig.dcp_comm_backend`.
DCPCommBackend = Literal["ag_rs", "a2a"]
# Allowed values for `ParallelConfig.all2all_backend`. "naive" and "pplx" are
# still accepted as input, but the ParallelConfig validator replaces them with
# "allgather_reducescatter" because those backends have been removed.
All2AllBackend = Literal[
    "naive",
    "pplx",
    "deepep_high_throughput",
    "deepep_low_latency",
    "mori",
    "nixl_ep",
    "allgather_reducescatter",
    "flashinfer_all2allv",  # temporary alias for flashinfer_nvlink_two_sided
    "flashinfer_nvlink_two_sided",
    "flashinfer_nvlink_one_sided",
]
50
51


52
53
54
55
@config
class EPLBConfig:
    """Configuration for Expert Parallel Load Balancing (EPLB)."""

    window_size: int = Field(default=1000, gt=0)
    """Window size for expert load recording."""
    step_interval: int = Field(default=3000, gt=0)
    """
    Interval for rearranging experts in expert parallelism.

    Note that if this is greater than the EPLB window size, only the metrics
    of the last `window_size` steps will be used for rearranging experts.
    """

    num_redundant_experts: int = Field(default=0, ge=0)
    """Number of redundant experts to use for expert parallelism."""

    log_balancedness: bool = False
    """
    Log the balancedness at each step of expert parallelism.
    This is turned off by default since it will cause communication overhead.
    """
    log_balancedness_interval: int = Field(default=1, gt=0)
    """
    Interval for logging the balancedness.
    """
    use_async: bool = False
    """
    Whether to use non-blocking EPLB.
    """

    policy: EPLBPolicyOption = "default"
    """The policy type for expert parallel load balancing (EPLB)."""

    @model_validator(mode="after")
    def _validate_eplb_config(self) -> Self:
        """Validate cross-field EPLB constraints after construction.

        Returns:
            The validated config.

        Raises:
            ValueError: If async EPLB is combined with a non-default policy,
                or balancedness logging is enabled with a non-positive
                logging interval.
        """
        if self.use_async and self.policy != "default":
            raise ValueError("Async EPLB is only supported with the default policy.")
        # Already enforced by `Field(gt=0)` on the field; kept as a defensive
        # check so the error message stays explicit.
        if self.log_balancedness and self.log_balancedness_interval <= 0:
            raise ValueError("log_balancedness_interval must be greater than 0.")
        return self

94

95
96
97
98
99
100
101
102
@config
class ParallelConfig:
    """Configuration for the distributed execution."""

    pipeline_parallel_size: int = 1
    """Number of pipeline parallel groups."""
    tensor_parallel_size: int = 1
    """Number of tensor parallel groups."""
103
104
    prefill_context_parallel_size: int = 1
    """Number of prefill context parallel groups."""
105
106
107
108
109
110
111
    data_parallel_size: int = 1
    """Number of data parallel groups. MoE layers will be sharded according to
    the product of the tensor parallel size and data parallel size."""
    data_parallel_size_local: int = 1
    """Number of local data parallel groups."""
    data_parallel_rank: int = 0
    """Rank of the data parallel group."""
112
    data_parallel_rank_local: int | None = None
113
114
115
116
117
118
119
120
    """Local rank of the data parallel group,
    set only in SPMD mode."""
    data_parallel_master_ip: str = "127.0.0.1"
    """IP of the data parallel master."""
    data_parallel_rpc_port: int = 29550
    """Port for data parallel messaging."""
    data_parallel_master_port: int = 29500
    """Port of the data parallel master."""
121
    data_parallel_backend: DataParallelBackend = "mp"
122
123
124
125
    """Backend to use for data parallel, either "mp" or "ray"."""
    data_parallel_external_lb: bool = False
    """Whether to use "external" DP LB mode. Applies only to online serving
    and when data_parallel_size > 1. This is useful for a "one-pod-per-rank"
co63oc's avatar
co63oc committed
126
    wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank
127
128
129
130
131
132
133
134
    is provided explicitly to vllm serve."""
    data_parallel_hybrid_lb: bool = False
    """Whether to use "hybrid" DP LB mode. Applies only to online serving
    and when data_parallel_size > 0. Enables running an AsyncLLM
    and API server on a "per-node" basis where vLLM load balances
    between local data parallel ranks, but an external LB balances
    between vLLM nodes/replicas. Set explicitly in conjunction with
    --data-parallel-start-rank."""
135
136
    is_moe_model: bool | None = None
    """Whether the deployed model is MoE (if known)."""
137
138
    enable_expert_parallel: bool = False
    """Use expert parallelism instead of tensor parallelism for MoE layers."""
139
140
141
142
143
144
145
    enable_ep_weight_filter: bool = False
    """Skip non-local expert weights during model loading when expert
    parallelism is active.  Each rank only reads its own expert shard from
    disk, which can drastically reduce storage I/O for MoE models with
    per-expert weight tensors (e.g. DeepSeek, Mixtral, Kimi-K2.5).  Has no
    effect on 3D fused-expert checkpoints (e.g. GPT-OSS) or non-MoE
    models."""
146
147
    enable_eplb: bool = False
    """Enable expert parallelism load balancing for MoE layers."""
148
    eplb_config: EPLBConfig = Field(default_factory=EPLBConfig)
149
    """Expert parallelism configuration."""
150
    expert_placement_strategy: ExpertPlacementStrategy = "linear"
151
152
    """The expert placement strategy for MoE layers:

153
154
    - "linear": Experts are placed in a contiguous manner. For example, with 4
      experts and 2 ranks, rank 0 will have experts [0, 1] and rank 1 will have
155
      experts [2, 3].
156
157
158
159
    - "round_robin": Experts are placed in a round-robin manner. For example,
      with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
      will have experts [1, 3]. This strategy can help improve load balancing
      for grouped expert models with no redundant experts."""
160
161
162
    all2all_backend: All2AllBackend = "allgather_reducescatter"
    """All2All backend for MoE expert parallel communication. Available options:

163
164
165
166
167
    - "allgather_reducescatter": All2all based on allgather and reducescatter
    - "deepep_high_throughput": Use deepep high-throughput kernels
    - "deepep_low_latency": Use deepep low-latency kernels
    - "mori": Use mori kernels
    - "nixl_ep": Use nixl-ep kernels
168
169
    - "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl
    - "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels"""
170

171
    max_parallel_loading_workers: int | None = None
172
173
174
175
176
177
178
    """Maximum number of parallel loading workers when loading model
    sequentially in multiple batches. To avoid RAM OOM when using tensor
    parallel and large models."""

    disable_custom_all_reduce: bool = False
    """Disable the custom all-reduce kernel and fall back to NCCL."""

179
180
181
    enable_elastic_ep: bool = False
    """Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""

182
    enable_dbo: bool = False
183
    """Enable dual batch overlap for the model executor."""
184
185
    ubatch_size: int = 0
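    """Number of micro-batches (ubatches) to split a batch into. Values greater
    than 1 enable micro-batching. Ignored when enable_dbo is set, which always
    uses 2 micro-batches."""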
186
187

    dbo_decode_token_threshold: int = 32
188
189
190
191
192
193
194
195
196
    """The threshold for dual batch overlap for batches only containing decodes.
    If the number of tokens in the request is greater than this threshold,
    microbatching will be used. Otherwise, the request will be processed in a
    single batch."""
    dbo_prefill_token_threshold: int = 512  # TODO(lucas): tune
    """The threshold for dual batch overlap for batches that contain one or more
    prefills. If the number of tokens in the request is greater than this
    threshold, microbatching will be used. Otherwise, the request will be
    processed in a single batch."""
197

198
    disable_nccl_for_dp_synchronization: bool | None = None
199
    """Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py 
200
201
202
203
    to use Gloo instead of NCCL for its all reduce.

    Defaults to True when async scheduling is enabled, False otherwise.
    """
204

205
206
207
    ray_workers_use_nsight: bool = False
    """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

208
    ray_runtime_env: RuntimeEnv | None = None
209
210
    """Ray runtime environment to pass to distributed workers."""

211
    placement_group: PlacementGroup | None = None
212
213
    """ray distributed model workers placement group."""

214
    distributed_executor_backend: (
215
        str | DistributedExecutorBackend | type[Executor] | None
216
    ) = None
217
218
219
220
221
222
223
224
    """Backend to use for distributed model workers, either "ray" or "mp"
    (multiprocessing). If the product of pipeline_parallel_size and tensor_parallel_size
    is less than or equal to the number of GPUs available, "mp" will be used to
    keep processing on a single host. Otherwise, an error will be raised. To use "mp"
    you must also set nnodes, and to use "ray" you must manually set
    distributed_executor_backend to "ray".

    Note that TPU only supports Ray for distributed inference."""
225
226
227
228
229
230
231
232
233
234
235
236

    worker_cls: str = "auto"
    """The full name of the worker class to use. If "auto", the worker class
    will be determined based on the platform."""
    sd_worker_cls: str = "auto"
    """The full name of the worker class to use for speculative decoding.
    If "auto", the worker class will be determined based on the platform."""
    worker_extension_cls: str = ""
    """The full name of the worker extension class to use. The worker extension
    class is dynamically inherited by the worker class. This is used to inject
    new attributes and methods to the worker class for use in collective_rpc
    calls."""
237
238
239
240
241
242
243
244
245
246
    master_addr: str = "127.0.0.1"
    """distributed master address for multi-node distributed 
    inference when distributed_executor_backend is mp."""
    master_port: int = 29501
    """distributed master port for multi-node distributed 
    inference when distributed_executor_backend is mp."""
    node_rank: int = 0
    """distributed node rank for multi-node distributed 
    inference when distributed_executor_backend is mp."""
    nnodes: int = 1
247
    """num of nodes for multi-node distributed
248
    inference when distributed_executor_backend is mp."""
249

250
251
252
253
254
255
    distributed_timeout_seconds: int | None = None
    """Timeout in seconds for distributed operations (e.g., init_process_group).
    If set, this value is passed to torch.distributed.init_process_group as the
    timeout parameter. If None, PyTorch's default timeout is used (600s for NCCL).
    Increase this for multi-node setups where model downloads may be slow."""

256
    world_size: int = Field(init=False)
257
258
259
260
261
    """world_size is PPxTPxPCP (additionally multiplied by DP when using the
    external launcher); it affects the number of workers we create."""

    rank: int = 0
    """Global rank in distributed setup."""

262
    _data_parallel_master_port_list: list[int] = Field(default_factory=list)
263
264
265
266
    """List of open ports auto-queried for data parallel process group init.
    Set to be private as it's not intended to be configured by users.
    """

267
268
269
    _coord_store_port: int = 0
    """Port of the coordination TCPStore. Can be set by the API server; workers
    connect as clients to exchange self-picked group ports at runtime."""
270

271
272
273
274
275
    decode_context_parallel_size: int = 1
    """Number of decode context parallel groups. Because DCP does not change
    the world size (it reuses the GPUs of the TP group), tp_size must be
    divisible by dcp_size."""

276
    dcp_kv_cache_interleave_size: int = 1
277
278
279
280
281
282
    """
    Interleave size of kv_cache storage while using DCP.
    dcp_kv_cache_interleave_size has been replaced by cp_kv_cache_interleave_size,
    and will be deprecated when PCP is fully supported.

    """
283
284
285
286
287
288
289
290
    dcp_comm_backend: DCPCommBackend = "ag_rs"
    """Communication backend for Decode Context Parallel (DCP).
    - "ag_rs": AllGather + ReduceScatter (default, existing behavior)
    - "a2a": All-to-All exchange of partial outputs + LSE, then
      combine with Triton kernel. Reduces NCCL calls from 3 to 2
      per layer for MLA models.
    """

291
292
293
    cp_kv_cache_interleave_size: int = 1
    """Interleave size of kv_cache storage while using DCP or PCP.
    With `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank` and
    `total_cp_world_size = pcp_world_size * dcp_world_size`, store
    interleave_size tokens on total_cp_rank i, then store the next
    interleave_size tokens on total_cp_rank i+1.

    Interleave_size=1: token-level alignment, where token `i` is stored on
        total_cp_rank `i % total_cp_world_size`.
    Interleave_size=block_size: block-level alignment, where tokens are
        first populated to the preceding ranks. Tokens are then stored
        in (rank i+1, block j) only after (rank i, block j) is fully occupied.
    Block_size should be greater than or equal to cp_kv_cache_interleave_size.
    Block_size should be divisible by cp_kv_cache_interleave_size.
304
305
    """

306
307
308
309
    data_parallel_index: int = Field(init=False)
    """Equal to the data parallel rank but not used for torch process groups
    and not overridden for dense models."""

310
    _api_process_count: int = Field(default=1, gt=0)
311
312
313
314
315
316
317
318
    """
    The number of API processes initialized.

    Note:
        This is an internal config that is only valid for and
        should only be set by API server scale-out.
    """

319
    _api_process_rank: int = Field(default=0, ge=-1)
320
321
322
323
324
325
326
327
328
    """
    The rank of this API process, or `-1` for engine core processes
    under API server scale-out.

    Note:
        This is an internal config that is only valid for and
        should only be set by API server scale-out.
    """

329
330
331
332
333
334
    @field_validator("disable_nccl_for_dp_synchronization", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        """Run normal field validation only for values other than `None`.

        `None` signals that initialisation is delayed, so it is passed through
        as-is instead of being handed to the wrapped validator.
        """
        if value is None:
            return None
        return handler(value)

335
336
337
338
339
340
341
342
343
    @model_validator(mode="after")
    def _validate_parallel_config(self) -> Self:
        """Validate cross-field parallelism constraints after construction.

        Also rewrites the removed all2all backends ("pplx", "naive") to
        "allgather_reducescatter", logging a warning.

        Returns:
            The validated (and possibly normalized) config.

        Raises:
            ValueError: If any parallelism settings are inconsistent.
        """
        # `_api_process_rank` may be -1 (engine core, allowed by `ge=-1`);
        # otherwise it must index into the API process pool.
        if self._api_process_rank >= self._api_process_count:
            raise ValueError(
                "Invalid value of `_api_process_rank`. "
                f"Expected to be `-1` or `[0, {self._api_process_count})`, "
                f"but found: {self._api_process_rank}"
            )

        if self.all2all_backend in ("pplx", "naive"):
            logger.warning(
                "The '%s' all2all backend has been removed. "
                "Falling back to 'allgather_reducescatter'.",
                self.all2all_backend,
            )
            self.all2all_backend = "allgather_reducescatter"

        if self.data_parallel_size_local > self.data_parallel_size:
            raise ValueError(
                f"data_parallel_size_local ({self.data_parallel_size_local}) "
                f"must be <= data_parallel_size ({self.data_parallel_size})"
            )

        if self.data_parallel_size <= 1 and self.data_parallel_external_lb:
            raise ValueError(
                "data_parallel_external_lb can only be set when data_parallel_size > 1"
            )

        if self.enable_eplb:
            if not current_platform.is_cuda_alike():
                raise ValueError(
                    "Expert parallelism load balancing is only supported on "
                    "CUDA devices or ROCm devices now."
                )
            if not self.enable_expert_parallel:
                raise ValueError("enable_expert_parallel must be True to use EPLB.")
            if self.tensor_parallel_size * self.data_parallel_size <= 1:
                raise ValueError(
                    "EPLB requires tensor_parallel_size or data_parallel_size "
                    f"to be greater than 1, but got "
                    f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
                )
        elif self.eplb_config.num_redundant_experts != 0:
            raise ValueError(
                "num_redundant_experts is set to "
                f"{self.eplb_config.num_redundant_experts} but EPLB is not "
                "enabled. Either enable EPLB or unset "
                "num_redundant_experts."
            )

        # Guard the divisibility check below: 0 would raise ZeroDivisionError
        # (not a validation error) and negative sizes would slip through it.
        if self.decode_context_parallel_size < 1:
            raise ValueError(
                "decode_context_parallel_size must be >= 1, but got "
                f"{self.decode_context_parallel_size}."
            )

        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
        # reuses the GPUs of TP group, and split one TP group into
        # tp_size//dcp_size DCP groups.
        if self.tensor_parallel_size % self.decode_context_parallel_size != 0:
            raise ValueError(
                f"tp_size={self.tensor_parallel_size} must be divisible by "
                f"dcp_size={self.decode_context_parallel_size}."
            )

        if self.dcp_comm_backend == "a2a" and self.decode_context_parallel_size <= 1:
            raise ValueError(
                "dcp_comm_backend='a2a' requires decode_context_parallel_size > 1."
            )

        return self

404
405
406
407
408
409
    @property
    def world_size_across_dp(self) -> int:
        """Size of the world including data parallelism.

        This is `world_size` multiplied by `data_parallel_size`.
        """
        dp_size = self.data_parallel_size
        return dp_size * self.world_size

410
411
412
413
414
415
416
417
    @property
    def use_ubatching(self) -> bool:
        """Whether micro-batching is active (DBO on, or ubatch_size > 1)."""
        if self.enable_dbo:
            return self.enable_dbo
        return self.ubatch_size > 1

    @property
    def num_ubatches(self) -> int:
        """Number of micro-batches: fixed at 2 under DBO, else `ubatch_size`."""
        if not self.enable_dbo:
            return self.ubatch_size
        return 2

418
419
420
421
422
423
424
425
    @property
    def local_engines_only(self) -> bool:
        """Whether the client manages only its local EngineCores.

        In the pure internal LB case the client manages both local and remote
        EngineCores; in the hybrid and external LB cases it manages only the
        local ones.
        """
        if self.data_parallel_external_lb:
            return self.data_parallel_external_lb
        return self.data_parallel_hybrid_lb

426
427
428
429
430
431
    def get_next_dp_init_port(self) -> int:
        """Hand out a port for initializing a DP-related process group.

        Process groups tied to data parallelism may be initialized in several
        processes (e.g. both the worker and the engine, which can live in
        different processes). To avoid port conflicts, each call returns a
        fresh port: one popped from the prepared port list while it is
        non-empty, otherwise the current `data_parallel_master_port`, which
        is then incremented.
        """
        port_pool = self._data_parallel_master_port_list
        if port_pool:
            return port_pool.pop()

        # Prepared list exhausted (or never filled): use the master port and
        # advance it so the next caller receives a different one.
        next_port = self.data_parallel_master_port
        self.data_parallel_master_port = next_port + 1
        return next_port

443
444
    def _pick_stateless_dp_port(self) -> tuple[int, socket.socket | None]:
        """Return ``(port, listen_socket)`` for DP group init.

        Without a coordination store (``_coord_store_port`` unset/0), a
        pre-allocated port is taken via ``get_next_dp_init_port`` and
        ``listen_socket`` is ``None``.

        With a coordination store, DP rank 0 binds a listening socket on an
        OS-assigned port (port 0) and publishes that port under the key
        ``"dp_master_port"``; every other rank reads the published port and
        gets ``listen_socket=None``.

        Returns:
            The port to use and, on DP rank 0 with a coord store, the bound
            listening socket, which the caller takes ownership of.
        """
        if not self._coord_store_port:
            return self.get_next_dp_init_port(), None

        from vllm.distributed.utils import get_cached_tcp_store_client

        store = get_cached_tcp_store_client(
            self.data_parallel_master_ip, self._coord_store_port
        )

        key = "dp_master_port"
        if self.data_parallel_rank != 0:
            # Non-zero ranks consume the port rank 0 published (torch's
            # Store.get waits for the key to appear, up to the store timeout).
            return int(store.get(key).decode()), None

        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # Binding to port 0 lets the OS choose a free port atomically.
            s.bind((self.data_parallel_master_ip, 0))
            s.listen()
            port = s.getsockname()[1]
            store.set(key, str(port).encode())
        except BaseException:
            # Do not leak the bound socket if binding or publishing fails.
            s.close()
            raise
        return port, s
469

470
471
472
473
474
475
476
477
478
479
480
    @overload
    def stateless_init_dp_group(
        self, return_store: Literal[False] = ...
    ) -> ProcessGroup: ...
    @overload
    def stateless_init_dp_group(
        self, return_store: Literal[True]
    ) -> tuple[ProcessGroup, Store]: ...
    def stateless_init_dp_group(
        self, return_store: bool = False
    ) -> ProcessGroup | tuple[ProcessGroup, Store]:
        """Create a stateless gloo process group spanning all DP ranks.

        The port comes from ``_pick_stateless_dp_port``. Initialization is
        retried (up to 5 attempts) when it fails with an EADDRINUSE
        ``DistNetworkError``.

        Args:
            return_store: If True, also return the ``Store`` produced by the
                process-group helper alongside the group.

        Returns:
            The DP process group, or ``(group, store)`` if ``return_store``.

        Raises:
            DistNetworkError: Immediately for non-EADDRINUSE failures, or the
                last EADDRINUSE error once all attempts are exhausted.
        """
        # NOTE: In high-concurrency scenarios multiple processes
        # can pick the same (currently free) port through a race
        # condition when calling `get_open_port()`. When the first
        # process binds the port the others will subsequently fail
        # with `torch.distributed.DistNetworkError: EADDRINUSE`.
        # To make the initialization more robust we retry a few times
        # with a fresh port whenever this specific error is observed.
        from torch.distributed import DistNetworkError

        from vllm.distributed.utils import (
            stateless_init_torch_distributed_process_group,
        )

        max_retries = 5
        last_exc: Exception | None = None
        for _ in range(max_retries):
            try:
                port, listen_socket = self._pick_stateless_dp_port()
                # use gloo since the engine process might not have cuda device
                return stateless_init_torch_distributed_process_group(
                    self.data_parallel_master_ip,
                    port,
                    self.data_parallel_rank,
                    self.data_parallel_size,
                    backend="gloo",
                    return_store=return_store,
                    listen_socket=listen_socket,
                )
            except DistNetworkError as e:
                # We only want to retry when the root cause is EADDRINUSE.
                if "EADDRINUSE" not in str(e):
                    raise
                logger.warning("Address already in use. Retrying with a new port.")
                last_exc = e  # loop again with a new port

        # If we get here all retries have failed.
        assert last_exc is not None
        raise last_exc

521
522
523
524
525
526
527
528
529
530
    # The all_reduce at the end of attention (during o_proj) means that
    # inputs are replicated across each rank of the tensor parallel group.
    # If using expert-parallelism with DeepEP All2All ops, replicated
    # tokens result in useless duplicate computation and communication.
    #
    # In this case, ensure the input to the experts is sequence parallel
    # to avoid the excess work.
    #
    @property
    def use_sequence_parallel_moe(self) -> bool:
        """Whether MoE inputs should be made sequence parallel.

        True only with expert parallelism enabled, both TP and DP sizes above
        1, and an all2all backend from the list below.
        """
        if not self.enable_expert_parallel:
            return False
        if self.tensor_parallel_size <= 1 or self.data_parallel_size <= 1:
            return False
        sequence_parallel_backends = (
            "allgather_reducescatter",
            "deepep_high_throughput",
            "deepep_low_latency",
            "mori",
            "nixl_ep",
        )
        return self.all2all_backend in sequence_parallel_backends
544

545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
    @property
    def node_rank_within_dp(self) -> int:
        """`node_rank` reduced modulo `nnodes_within_dp`."""
        nodes_per_group = self.nnodes_within_dp
        return self.node_rank % nodes_per_group

    @property
    def nnodes_within_dp(self) -> int:
        """Number of nodes per group of local DP ranks.

        Always 1 for a single node; otherwise `nnodes` divided by
        `data_parallel_size // data_parallel_size_local` (integer division).
        """
        if self.nnodes != 1:
            dp_node_groups = self.data_parallel_size // self.data_parallel_size_local
            return self.nnodes // dp_node_groups
        return 1

    @property
    def local_world_size(self) -> int:
        """`world_size` divided (integer division) by `nnodes_within_dp`."""
        node_count = self.nnodes_within_dp
        return self.world_size // node_count

562
    @staticmethod
    def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
        """Return True if any rank in ``dp_group`` reports unfinished work.

        Each rank contributes its local flag as an int32 0/1 value. A MAX
        all-reduce over those values is a logical OR across ranks (e.g. rank 0
        True and rank 1 False gives True). This is a collective call, so every
        rank in the group must make it.
        """
        flag = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
        torch.distributed.all_reduce(flag, op=ReduceOp.MAX, group=dp_group)
        return bool(flag.item())

    @staticmethod
    def sync_kv_cache_memory_size(dp_group: ProcessGroup, kv_cache_memory: int) -> int:
        """Return the smallest KV-cache memory size across the DP group.

        A local value of ``-1`` is mapped to the int64 maximum before the
        reduction, so it never wins the MIN unless every rank passed ``-1``
        (in which case the int64 maximum is returned).
        """
        if kv_cache_memory == -1:
            local_value = torch.iinfo(torch.int64).max
        else:
            local_value = kv_cache_memory
        buf = torch.tensor([local_value], dtype=torch.int64, device="cpu")
        # A stateless DP group cannot use broadcast (it depends on the global
        # rank), so the agreed value is obtained with an all-reduce MIN.
        torch.distributed.all_reduce(buf, op=ReduceOp.MIN, group=dp_group)
        return buf.item()

    def compute_hash(self):
        """
        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        This hash is also used for DP worker configuration validation
        to prevent hangs from mismatched collective communication patterns.

        Fields that only describe runtime topology, networking, or launch
        details are excluded so the hash stays stable across runs that differ
        only in such details (e.g. freshly picked ports).
        """
        ignored_factors = {
            # Derived/runtime topology, networking, or launch details
            "data_parallel_rank",
            "data_parallel_rank_local",
            "data_parallel_size_local",
            "data_parallel_index",
            "data_parallel_backend",
            "data_parallel_external_lb",
            "data_parallel_hybrid_lb",
            "data_parallel_master_ip",
            "data_parallel_master_port",
            "_data_parallel_master_port_list",
            "data_parallel_rpc_port",
            # Coordination TCPStore port picked at runtime by the API server;
            # like the other ports above it must not perturb the hash.
            "_coord_store_port",
            "rank",
            "master_addr",
            "master_port",
            "node_rank",
            "nnodes",
            # A timeout does not change the computation graph.
            "distributed_timeout_seconds",
            "max_parallel_loading_workers",
            "disable_custom_all_reduce",
            "ray_workers_use_nsight",
            "ray_runtime_env",
            "placement_group",
            "distributed_executor_backend",
            "worker_cls",
            "sd_worker_cls",
            "worker_extension_cls",
            "_api_process_count",
            "_api_process_rank",
        }

        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, ignored_factors)
        return hash_factors(factors)
629
630

    def __post_init__(self) -> None:
631
        # Continue with the rest of the initialization
632
633
634
635
636
        self.world_size = (
            self.pipeline_parallel_size
            * self.tensor_parallel_size
            * self.prefill_context_parallel_size
        )
637

638
639
640
641
        if self.distributed_executor_backend == "external_launcher":
            logger.info("Using external launcher for distributed inference.")
            self.world_size *= self.data_parallel_size

642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
        if self.enable_elastic_ep:
            if not self.enable_eplb:
                raise ValueError("Elastic EP is only supported with enable_eplb=True.")
            if self.pipeline_parallel_size > 1:
                raise ValueError(
                    "Elastic EP is not supported with pipeline parallelism "
                    f"(pipeline_parallel_size={self.pipeline_parallel_size})."
                )
            if self.data_parallel_external_lb or self.data_parallel_hybrid_lb:
                raise NotImplementedError(
                    "Elastic EP is not compatible with data_parallel_external_lb "
                    "or data_parallel_hybrid_lb. Elastic EP relies on a single API "
                    "server and core client to coordinate scale up/down."
                )

657
658
        if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
            # Data parallel was specified in the engine args.
659
660
661
            if self.distributed_executor_backend == "external_launcher":
                # For external launcher,
                # we need to set the data parallel rank automatically
662
663
664
665
666
667
668
                self.data_parallel_rank = int(os.environ["RANK"]) // (
                    self.world_size // self.data_parallel_size
                )
                logger.info(
                    "Set data_parallel_rank to %d automatically.",
                    self.data_parallel_rank,
                )
669
670
671
672
673
674
            if not self.enable_elastic_ep:
                if not self._data_parallel_master_port_list:
                    self._data_parallel_master_port_list = get_open_ports_list(5)
                self.data_parallel_master_port = (
                    self._data_parallel_master_port_list.pop()
                )
675
676
677
678

            if not (0 <= self.data_parallel_rank < self.data_parallel_size):
                raise ValueError(
                    f"data_parallel_rank ({self.data_parallel_rank})"
679
680
                    f" must be in the range [0, {self.data_parallel_size})"
                )
681
682
683
684
685
686
687
688
        else:
            # Otherwise fall back to env vars (e.g. for offline SPMD case).
            self.data_parallel_size = envs.VLLM_DP_SIZE
            self.data_parallel_rank = envs.VLLM_DP_RANK
            self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
            self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
            self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT

689
690
691
692
693
694
695
696
            if self.data_parallel_size > 1 and self.is_moe_model is False:
                raise ValueError(
                    "Offline data parallel mode is not supported/useful"
                    " for dense models."
                )

        self.data_parallel_index = self.data_parallel_rank

697
698
699
700
        if self.distributed_executor_backend == "external_launcher":
            os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
            logger.info("Disabling V1 multiprocessing for external launcher.")

701
        if self.distributed_executor_backend is None and self.world_size_across_dp > 1:
702
703
704
            # We use multiprocessing by default if world_size fits on the
            # current node and we aren't in a ray placement group.

705
            from vllm.v1.executor import ray_utils
706

707
708
            backend: DistributedExecutorBackend = "mp"
            ray_found = ray_utils.ray_is_available()
709
            if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
710
                backend = "uni"
711
712
            elif current_platform.is_cuda() and self.nnodes > 1:
                backend = "mp"
713
714
715
716
            elif (
                current_platform.is_cuda()
                and cuda_device_count_stateless() < self.world_size
            ):
717
718
                gpu_count = cuda_device_count_stateless()
                raise ValueError(
719
720
721
722
723
                    f"World size ({self.world_size}) is larger than the number of "
                    f"available GPUs ({gpu_count}) in this node. If this is "
                    "intentional and you are using:\n"
                    "- ray, set '--distributed-executor-backend ray'.\n"
                    "- multiprocessing, set '--nnodes' appropriately."
724
                )
725
            elif self.data_parallel_backend == "ray":
726
727
728
729
                logger.info(
                    "Using ray distributed inference because "
                    "data_parallel_backend is ray"
                )
730
731
732
733
734
735
                backend = "ray"
            elif ray_found:
                if self.placement_group:
                    backend = "ray"
                else:
                    from ray import is_initialized as ray_is_initialized
736

737
738
                    if ray_is_initialized():
                        from ray.util import get_current_placement_group
739

740
741
742
                        if get_current_placement_group():
                            backend = "ray"
            self.distributed_executor_backend = backend
743
            logger.debug("Defaulting to use %s for distributed inference", backend)
744
745
746
747

        if self.distributed_executor_backend is None and self.world_size == 1:
            self.distributed_executor_backend = "uni"

748
749
750
751
752
        if self.max_parallel_loading_workers is not None:
            logger.warning(
                "max_parallel_loading_workers is currently "
                "not supported and will be ignored."
            )
753
754
755
756
757
        allowed_backends = ("mp", "uni", "external_launcher")
        if (
            self.distributed_executor_backend not in allowed_backends
            and self.nnodes > 1
        ):
758
            raise ValueError(
759
                "nnodes > 1 can only be set when distributed executor "
760
                "backend is mp, uni or external_launcher."
761
            )
762

763
        if (
764
            self.all2all_backend in ("allgather_reducescatter")
765
766
767
768
769
770
771
772
773
            and self.eplb_config.use_async
        ):
            logger.warning(
                "Async EPLB causes hangs with the '%s' all2all backend. "
                "Forcing synchronous EPLB.",
                self.all2all_backend,
            )
            self.eplb_config.use_async = False

774
775
776
777
    @property
    def use_ray(self) -> bool:
        """Whether the configured distributed executor runs on Ray.

        True when ``distributed_executor_backend`` is the string ``"ray"``,
        or when it is a class (e.g. a custom executor class) whose
        ``uses_ray`` attribute is truthy. Any other value (other strings,
        ``None``, classes without ``uses_ray``) yields False.
        """
        backend = self.distributed_executor_backend
        if backend == "ray":
            return True
        # Custom executor classes advertise Ray usage via a ``uses_ray``
        # class attribute; normalise it so callers always get a real bool.
        return isinstance(backend, type) and bool(
            getattr(backend, "uses_ray", False)
        )

    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Validate executor settings and disable unusable all-reduce paths.

        Checks that ``distributed_executor_backend`` is ``None``, a string,
        or an ``Executor`` subclass. When :attr:`use_ray` is true, it calls
        ``ray_utils.assert_ray_available()``. It forces
        ``disable_custom_all_reduce`` on in these cases:

        - batch-invariant mode is requested;
        - the current platform does not support the custom kernel;
        - ``nnodes > 1``.

        Returns:
            ``self`` (pydantic ``mode="after"`` model validators return the
            validated instance).

        Raises:
            ValueError: If ``distributed_executor_backend`` is neither
                ``None``, a string, nor an ``Executor`` subclass, or if
                ``ray_workers_use_nsight`` is set without a Ray executor.
        """
        # Lazy import to avoid circular import
        from vllm.v1.executor import Executor

        # Enable batch invariance settings if requested
        if envs.VLLM_BATCH_INVARIANT:
            self.disable_custom_all_reduce = True

        backend = self.distributed_executor_backend
        # Strings are accepted as-is here (named backends or import paths);
        # non-string values must be Executor subclasses.
        if (
            backend is not None
            and not isinstance(backend, str)
            and not (isinstance(backend, type) and issubclass(backend, Executor))
        ):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{backend}. Supported "
                "values are 'ray', 'mp', 'uni', 'external_launcher', "
                "a custom Executor subclass or its import path."
            )

        if self.use_ray:
            from vllm.v1.executor import ray_utils

            ray_utils.assert_ray_available()

        if not current_platform.use_custom_allreduce():
            self.disable_custom_all_reduce = True
            logger.debug(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on current platform."
            )

        if self.nnodes > 1:
            self.disable_custom_all_reduce = True
            logger.debug(
                "Disabled the custom all-reduce since we are running on multi-node."
            )

        # nsight profiling is wired through Ray workers only.
        if self.ray_workers_use_nsight and not self.use_ray:
            raise ValueError(
                "Unable to use nsight profiling unless workers run with Ray."
            )

        return self