parallel.py 29.3 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import os
5
from collections.abc import Callable
6
from dataclasses import replace
7
from typing import TYPE_CHECKING, Any, Literal
8
9

import torch
10
from pydantic import Field, field_validator, model_validator
11
12
13
14
15
16
17
from pydantic.dataclasses import dataclass
from torch.distributed import ProcessGroup, ReduceOp
from typing_extensions import Self

import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
18
from vllm.model_executor.layers.batch_invariant import (
19
    vllm_is_batch_invariant,
20
)
21
from vllm.platforms import current_platform
22
from vllm.utils.network_utils import get_open_ports_list
23
from vllm.utils.torch_utils import cuda_device_count_stateless
24
25
26
27
28

if TYPE_CHECKING:
    from ray.runtime_env import RuntimeEnv
    from ray.util.placement_group import PlacementGroup

29
    from vllm.v1.executor import Executor
30
31
32
else:
    RuntimeEnv = Any
    PlacementGroup = Any
33
    Executor = Any
34
35
36

logger = init_logger(__name__)

# How experts of a MoE layer are assigned to expert-parallel ranks.
ExpertPlacementStrategy = Literal[
    "linear",
    "round_robin",
]
# Executor backends accepted by name for launching distributed model workers.
DistributedExecutorBackend = Literal[
    "ray",
    "mp",
    "uni",
    "external_launcher",
]
# Backends available for data-parallel deployments.
DataParallelBackend = Literal[
    "ray",
    "mp",
]
# Expert parallel load balancing (EPLB) policies.
EPLBPolicyOption = Literal["default"]
# All2All communication backends for MoE expert parallelism.
All2AllBackend = Literal[
    "naive",
    "pplx",
    "deepep_high_throughput",
    "deepep_low_latency",
    "mori",
    "allgather_reducescatter",
    "flashinfer_all2allv",
]
50
51


52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
@config
@dataclass
class EPLBConfig:
    """Configuration for Expert Parallel Load Balancing (EPLB)."""

    window_size: int = 1000
    """Window size for expert load recording."""
    step_interval: int = 3000
    """
    Interval for rearranging experts in expert parallelism.

    Note that if this is greater than the EPLB window size, only the metrics
    of the last `window_size` steps will be used for rearranging experts.
    """

    num_redundant_experts: int = Field(default=0, ge=0)
    """Number of redundant experts to use for expert parallelism."""

    log_balancedness: bool = False
    """
    Log the balancedness each step of expert parallelism.
    This is turned off by default since it will cause communication overhead.
    """
    log_balancedness_interval: int = 1
    """
    Interval for logging the balancedness. Must be greater than 0 when
    `log_balancedness` is enabled.
    """
    use_async: bool = False
    """
    Whether to use non-blocking EPLB. Only supported with the "default"
    policy.
    """

    policy: EPLBPolicyOption = "default"
    """The policy type for expert parallel load balancing (EPLB)."""

    @model_validator(mode="after")
    def _validate_eplb_config(self) -> Self:
        """Reject EPLB option combinations that are not supported.

        Raises:
            ValueError: if async EPLB is combined with a non-default policy,
                or if balancedness logging is enabled with a non-positive
                interval.
        """
        if self.use_async and self.policy != "default":
            raise ValueError("Async EPLB is only supported with the default policy.")
        if self.log_balancedness and self.log_balancedness_interval <= 0:
            raise ValueError("log_balancedness_interval must be greater than 0.")
        return self

95

96
97
98
99
100
101
102
103
104
@config
@dataclass
class ParallelConfig:
    """Configuration for the distributed execution."""

    pipeline_parallel_size: int = 1
    """Number of pipeline parallel groups."""
    tensor_parallel_size: int = 1
    """Number of tensor parallel groups."""
    prefill_context_parallel_size: int = 1
    """Number of prefill context parallel groups."""
    data_parallel_size: int = 1
    """Number of data parallel groups. MoE layers will be sharded according to
    the product of the tensor parallel size and data parallel size."""
    data_parallel_size_local: int = 1
    """Number of local data parallel groups."""
    data_parallel_rank: int = 0
    """Rank of the data parallel group."""
    data_parallel_rank_local: int | None = None
    """Local rank of the data parallel group,
    set only in SPMD mode."""
    data_parallel_master_ip: str = "127.0.0.1"
    """IP of the data parallel master."""
    data_parallel_rpc_port: int = 29550
    """Port for data parallel messaging."""
    data_parallel_master_port: int = 29500
    """Port of the data parallel master."""
    data_parallel_backend: DataParallelBackend = "mp"
    """Backend to use for data parallel, either "mp" or "ray"."""
    data_parallel_external_lb: bool = False
    """Whether to use "external" DP LB mode. Applies only to online serving
    and when data_parallel_size > 1. This is useful for a "one-pod-per-rank"
    wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank
    is provided explicitly to vllm serve."""
    data_parallel_hybrid_lb: bool = False
    """Whether to use "hybrid" DP LB mode. Applies only to online serving
    and when data_parallel_size > 0. Enables running an AsyncLLM
    and API server on a "per-node" basis where vLLM load balances
    between local data parallel ranks, but an external LB balances
    between vLLM nodes/replicas. Set explicitly in conjunction with
    --data-parallel-start-rank."""
    is_moe_model: bool | None = None
    """Whether the deployed model is MoE (if known)."""
    enable_expert_parallel: bool = False
    """Use expert parallelism instead of tensor parallelism for MoE layers."""
    enable_eplb: bool = False
    """Enable expert parallelism load balancing for MoE layers."""
    eplb_config: EPLBConfig = Field(default_factory=EPLBConfig)
    """Expert parallelism load balancing (EPLB) configuration."""
    expert_placement_strategy: ExpertPlacementStrategy = "linear"
    """The expert placement strategy for MoE layers:\n
    - "linear": Experts are placed in a contiguous manner. For example, with 4
      experts and 2 ranks, rank 0 will have experts [0, 1] and rank 1 will have
      experts [2, 3].\n
    - "round_robin": Experts are placed in a round-robin manner. For example,
      with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
      will have experts [1, 3]. This strategy can help improve load balancing
      for grouped expert models with no redundant experts."""
    all2all_backend: All2AllBackend = "allgather_reducescatter"
    """All2All backend for MoE expert parallel communication. Available options:

    - "naive": Naive all2all implementation using broadcasts\n
    - "allgather_reducescatter": All2all based on allgather and reducescatter\n
    - "pplx": Use pplx kernels\n
    - "deepep_high_throughput": Use deepep high-throughput kernels\n
    - "deepep_low_latency": Use deepep low-latency kernels\n
    - "mori": Use mori kernels\n
    - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""

    max_parallel_loading_workers: int | None = None
    """Maximum number of parallel loading workers when loading model
    sequentially in multiple batches. To avoid RAM OOM when using tensor
    parallel and large models. Currently not supported: setting it only logs a
    warning and the value is ignored."""

    disable_custom_all_reduce: bool = False
    """Disable the custom all-reduce kernel and fall back to NCCL."""

    enable_dbo: bool = False
    """Enable dual batch overlap for the model executor."""
    ubatch_size: int = 0
    """Number of micro-batches (ubatches) to split a batch into when dual batch
    overlap is disabled. Values greater than 1 enable micro-batching."""

    dbo_decode_token_threshold: int = 32
    """The threshold for dual batch overlap for batches only containing decodes.
    If the number of tokens in the request is greater than this threshold,
    microbatching will be used. Otherwise, the request will be processed in a
    single batch."""
    dbo_prefill_token_threshold: int = 512  # TODO(lucas): tune
    """The threshold for dual batch overlap for batches that contain one or more
    prefills. If the number of tokens in the request is greater than this
    threshold, microbatching will be used. Otherwise, the request will be
    processed in a single batch."""

    disable_nccl_for_dp_synchronization: bool = Field(default=None)
    """Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py
    to use Gloo instead of NCCL for its all reduce.

    Defaults to True when async scheduling is enabled, False otherwise.
    """

    ray_workers_use_nsight: bool = False
    """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

    ray_runtime_env: RuntimeEnv | None = None
    """Ray runtime environment to pass to distributed workers."""

    placement_group: PlacementGroup | None = None
    """ray distributed model workers placement group."""

    distributed_executor_backend: (
        str | DistributedExecutorBackend | type[Executor] | None
    ) = None
    """Backend to use for distributed model workers, either "ray" or "mp"
    (multiprocessing). If the product of pipeline_parallel_size and tensor_parallel_size
    is less than or equal to the number of GPUs available, "mp" will be used to
    keep processing on a single host. Otherwise, an error will be raised. To use "mp"
    you must also set nnodes, and to use "ray" you must manually set
    distributed_executor_backend to "ray".

    Note that TPU only supports Ray for distributed inference."""

    worker_cls: str = "auto"
    """The full name of the worker class to use. If "auto", the worker class
    will be determined based on the platform."""
    sd_worker_cls: str = "auto"
    """The full name of the worker class to use for speculative decoding.
    If "auto", the worker class will be determined based on the platform."""
    worker_extension_cls: str = ""
    """The full name of the worker extension class to use. The worker extension
    class is dynamically inherited by the worker class. This is used to inject
    new attributes and methods to the worker class for use in collective_rpc
    calls."""
    master_addr: str = "127.0.0.1"
    """distributed master address for multi-node distributed
    inference when distributed_executor_backend is mp."""
    master_port: int = 29501
    """distributed master port for multi-node distributed
    inference when distributed_executor_backend is mp."""
    node_rank: int = 0
    """distributed node rank for multi-node distributed
    inference when distributed_executor_backend is mp."""
    nnodes: int = 1
    """num of nodes for multi-node distributed
    inference when distributed_executor_backend is mp."""

    world_size: int = Field(init=False)
    """world_size is PP x TP x PCP (additionally multiplied by DP when the
    distributed executor backend is "external_launcher"). It affects the
    number of workers we create."""

    rank: int = 0
    """Global rank in distributed setup."""

    _data_parallel_master_port_list: list[int] = Field(default_factory=list)
    """List of open port auto-queried for data parallel messaging.
    Set to be private as it's not intended to be configured by users.
    """

    decode_context_parallel_size: int = 1
    """Number of decode context parallel groups, because the world size does
    not change by dcp, it simply reuse the GPUs of TP group, and tp_size
    needs to be divisible by dcp_size."""

    dcp_kv_cache_interleave_size: int = 1
    """
    Interleave size of kv_cache storage while using DCP.
    dcp_kv_cache_interleave_size has been replaced by cp_kv_cache_interleave_size,
    and will be deprecated when PCP is fully supported.

    """
    cp_kv_cache_interleave_size: int = 1
    """Interleave size of kv_cache storage while using DCP or PCP.
    For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
        and `total_cp_world_size = pcp_world_size * dcp_world_size`.
    store interleave_size tokens on total_cp_rank i,
    then store next interleave_size tokens on total_cp_rank i+1.
    Interleave_size=1: token-level alignment, where token `i` is stored on
        total_cp_rank `i % total_cp_world_size`.
    Interleave_size=block_size: block-level alignment, where tokens are
        first populated to the preceding ranks. Tokens are then stored
        in (rank i+1, block j) only after (rank i, block j) is fully occupied.
    Block_size should be greater than or equal to cp_kv_cache_interleave_size.
    Block_size should be divisible by cp_kv_cache_interleave_size.
    """

    data_parallel_index: int = Field(init=False)
    """Equal to the data parallel rank but not used for torch process groups
    and not overridden for dense models."""

    _api_process_count: int = Field(default=1, gt=0)
    """
    The number of API processes initialized.

    Note:
        This is an internal config that is only valid for and
        should only be set by API server scale-out.
    """

    _api_process_rank: int = Field(default=0, ge=-1)
    """
    The rank of this API process, or `-1` for engine core processes
    under API server scale-out.

    Note:
        This is an internal config that is only valid for and
        should only be set by API server scale-out.
    """

    @field_validator("disable_nccl_for_dp_synchronization", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        """Skip validation if the value is `None` when initialisation is delayed."""
        return None if value is None else handler(value)

    @model_validator(mode="after")
    def _validate_parallel_config(self) -> Self:
        """Cross-field validation of the parallel layout.

        Raises:
            ValueError: on an out-of-range API process rank, a local DP size
                larger than the global DP size, external LB without DP, an
                invalid EPLB setup, or an invalid DCP size.
        """
        if self._api_process_rank >= self._api_process_count:
            raise ValueError(
                "Invalid value of `_api_process_rank`. "
                f"Expected to be `-1` or `[0, {self._api_process_count})`, "
                f"but found: {self._api_process_rank}"
            )

        if self.data_parallel_size_local > self.data_parallel_size:
            raise ValueError(
                f"data_parallel_size_local ({self.data_parallel_size_local}) "
                f"must be <= data_parallel_size ({self.data_parallel_size})"
            )

        if self.data_parallel_size <= 1 and self.data_parallel_external_lb:
            raise ValueError(
                "data_parallel_external_lb can only be set when data_parallel_size > 1"
            )

        if self.enable_eplb:
            if not current_platform.is_cuda_alike():
                raise ValueError(
                    "Expert parallelism load balancing is only supported on "
                    "CUDA devices or ROCm devices now."
                )
            if not self.enable_expert_parallel:
                raise ValueError("enable_expert_parallel must be True to use EPLB.")
            if self.tensor_parallel_size * self.data_parallel_size <= 1:
                raise ValueError(
                    "EPLB requires tensor_parallel_size or data_parallel_size "
                    f"to be greater than 1, but got "
                    f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
                )
        elif self.eplb_config.num_redundant_experts != 0:
            raise ValueError(
                "num_redundant_experts is set to "
                f"{self.eplb_config.num_redundant_experts} but EPLB is not "
                "enabled. Either enable EPLB or unset "
                "num_redundant_experts."
            )

        # Guard the modulo below: a zero DCP size would otherwise surface as a
        # bare ZeroDivisionError instead of a validation error.
        if self.decode_context_parallel_size < 1:
            raise ValueError(
                "decode_context_parallel_size must be >= 1, but got "
                f"{self.decode_context_parallel_size}."
            )

        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
        # reuses the GPUs of TP group, and split one TP group into
        # tp_size//dcp_size DCP groups.
        if self.tensor_parallel_size % self.decode_context_parallel_size != 0:
            raise ValueError(
                f"tp_size={self.tensor_parallel_size} must be divisible by "
                f"dcp_size={self.decode_context_parallel_size}."
            )

        return self

    @property
    def world_size_across_dp(self) -> int:
        """world_size x data_parallel_size: the size of the world including
        data parallelism."""
        # NOTE(review): with the "external_launcher" backend, __post_init__
        # already multiplies world_size by data_parallel_size, so DP is counted
        # twice here — confirm whether callers rely on that.
        return self.world_size * self.data_parallel_size

    @property
    def use_ubatching(self) -> bool:
        """Whether micro-batching is active (DBO enabled or ubatch_size > 1)."""
        return self.enable_dbo or self.ubatch_size > 1

    @property
    def num_ubatches(self) -> int:
        """Number of micro-batches: 2 under DBO, otherwise ubatch_size."""
        return 2 if self.enable_dbo else self.ubatch_size

    @property
    def local_engines_only(self) -> bool:
        """
        Client manages local+remote EngineCores in pure internal LB case.
        Client manages local EngineCores in hybrid and external LB case.
        """
        return self.data_parallel_external_lb or self.data_parallel_hybrid_lb

    def get_next_dp_init_port(self) -> int:
        """
        We might need to initialize process groups in multiple
        processes that is related to data parallelism,
        e.g. both in the worker and in the engine, which
        can live in different processes. To avoid port conflicts, we
        pop a new port from the prepared port list each time we need to
        initialize a new process group related to data parallelism.

        When the prepared list is exhausted, fall back to
        data_parallel_master_port and increment it for the next call.
        """
        if self._data_parallel_master_port_list:
            answer = self._data_parallel_master_port_list.pop()
        else:
            answer = self.data_parallel_master_port
            self.data_parallel_master_port += 1

        return answer

    def stateless_init_dp_group(self) -> ProcessGroup:
        """Create the stateless data-parallel process group.

        Retries up to 5 times with a fresh port (via `get_next_dp_init_port`)
        when initialization fails with EADDRINUSE.

        Raises:
            DistNetworkError: immediately for any non-EADDRINUSE network error,
                or the last EADDRINUSE error once all retries are exhausted.
        """
        # NOTE: In high-concurrency scenarios multiple processes
        # can pick the same (currently free) port through a race
        # condition when calling `get_open_port()`. When the first
        # process binds the port the others will subsequently fail
        # with `torch.distributed.DistNetworkError: EADDRINUSE`.
        # To make the initialization more robust we retry a few times
        # with a fresh port whenever this specific error is observed.
        from torch.distributed import DistNetworkError

        from vllm.distributed.utils import (
            stateless_init_torch_distributed_process_group,
        )

        max_retries = 5
        last_exc: Exception | None = None
        for _ in range(max_retries):
            try:
                # The communication backend is chosen by the current platform.
                return stateless_init_torch_distributed_process_group(
                    self.data_parallel_master_ip,
                    self.get_next_dp_init_port(),
                    self.data_parallel_rank,
                    self.data_parallel_size,
                    backend=current_platform.dist_backend,
                )
            except DistNetworkError as e:
                # We only want to retry when the root cause is EADDRINUSE.
                if "EADDRINUSE" not in str(e):
                    raise
                logger.warning("Address already in use. Retrying with a new port.")
                last_exc = e

        # If we get here all retries have failed.
        assert last_exc is not None
        raise last_exc

    # The all_reduce at the end of attention (during o_proj) means that
    # inputs are replicated across each rank of the tensor parallel group.
    # If using expert-parallelism with DeepEP All2All ops, replicated
    # tokens results in useless duplicate computation and communication.
    #
    # In this case, ensure the input to the experts is sequence parallel
    # to avoid the excess work.
    #
    # Not needed for pplx-kernels as it can handle duplicate input tokens.
    @property
    def use_sequence_parallel_moe(self) -> bool:
        return (
            self.all2all_backend
            in (
                "allgather_reducescatter",
                "naive",
                "deepep_high_throughput",
                "deepep_low_latency",
                "mori",
            )
            and self.enable_expert_parallel
            and self.tensor_parallel_size > 1
            and self.data_parallel_size > 1
        )

    @property
    def node_rank_within_dp(self) -> int:
        """node_rank modulo nnodes_within_dp."""
        return self.node_rank % self.nnodes_within_dp

    @property
    def nnodes_within_dp(self) -> int:
        """nnodes divided by the number of DP node groups
        (data_parallel_size // data_parallel_size_local); 1 on a single node."""
        if self.nnodes == 1:
            return 1
        # NOTE(review): raises ZeroDivisionError if data_parallel_size_local
        # is 0 while nnodes > 1 — confirm that combination cannot reach here.
        data_parallel_node_size = (
            self.data_parallel_size // self.data_parallel_size_local
        )
        return self.nnodes // data_parallel_node_size

    @property
    def local_world_size(self) -> int:
        """world_size divided by nnodes_within_dp."""
        return self.world_size // self.nnodes_within_dp

    @staticmethod
    def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
        """Return True if any rank in `dp_group` reports unfinished work."""
        tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
        # dp rank 0: has_unfinished_seqs=True
        # dp rank 1: has_unfinished_seqs=False
        # aggregated: has_unfinished_seqs=True
        # so this is an OR operation, i.e. MAX in integers
        torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
        aggregated_has_unfinished = bool(tensor.item())
        return aggregated_has_unfinished

    @staticmethod
    def sync_kv_cache_memory_size(dp_group: ProcessGroup, kv_cache_memory: int) -> int:
        """Return the minimum `kv_cache_memory` across all ranks of `dp_group`.

        A value of -1 is replaced by the int64 maximum before the reduction,
        so it never wins the MIN against a real value.
        """
        # NOTE(review): if every rank passes -1 the result is the int64 max,
        # not -1 — confirm callers expect that.
        if kv_cache_memory == -1:
            kv_cache_memory = torch.iinfo(torch.int64).max
        tensor = torch.tensor([kv_cache_memory], dtype=torch.int64, device="cpu")
        # we cannot use broadcast for stateless dp group since it depends
        # on global rank
        torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group)
        return tensor.item()

    def compute_hash(self):
        """
        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        This hash is also used for DP worker configuration validation
        to prevent hangs from mismatched collective communication patterns.
        """
        ignored_factors = {
            # Derived/runtime topology, networking, or launch details
            "data_parallel_rank",
            "data_parallel_rank_local",
            "data_parallel_size_local",
            "data_parallel_index",
            "data_parallel_backend",
            "data_parallel_external_lb",
            "data_parallel_hybrid_lb",
            "data_parallel_master_ip",
            "data_parallel_master_port",
            "_data_parallel_master_port_list",
            "data_parallel_rpc_port",
            "rank",
            "master_addr",
            "master_port",
            "node_rank",
            "nnodes",
            "max_parallel_loading_workers",
            "disable_custom_all_reduce",
            "ray_workers_use_nsight",
            "ray_runtime_env",
            "placement_group",
            "distributed_executor_backend",
            "worker_cls",
            "sd_worker_cls",
            "worker_extension_cls",
            "_api_process_count",
            "_api_process_rank",
        }

        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, ignored_factors)
        return hash_factors(factors)

    def __post_init__(self) -> None:
        """Derive world size, DP rank/ports and the default executor backend.

        Raises:
            ValueError: on an out-of-range DP rank, offline DP for a dense
                model, a CUDA world size larger than the local GPU count, or
                nnodes > 1 with a backend other than mp/uni/external_launcher.
        """
        self.world_size = (
            self.pipeline_parallel_size
            * self.tensor_parallel_size
            * self.prefill_context_parallel_size
        )

        if self.distributed_executor_backend == "external_launcher":
            logger.info("Using external launcher for distributed inference.")
            self.world_size *= self.data_parallel_size

        if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
            # Data parallel was specified in the engine args.
            if self.distributed_executor_backend == "external_launcher":
                # For external launcher, derive the data parallel rank from
                # the launcher-provided global RANK.
                self.data_parallel_rank = int(os.environ["RANK"]) // (
                    self.world_size // self.data_parallel_size
                )
                logger.info(
                    "Set data_parallel_rank to %d automatically.",
                    self.data_parallel_rank,
                )
            if not self._data_parallel_master_port_list:
                self._data_parallel_master_port_list = get_open_ports_list(5)
            self.data_parallel_master_port = self._data_parallel_master_port_list.pop()

            if not (0 <= self.data_parallel_rank < self.data_parallel_size):
                raise ValueError(
                    f"data_parallel_rank ({self.data_parallel_rank})"
                    f" must be in the range [0, {self.data_parallel_size})"
                )
        else:
            # Otherwise fall back to env vars (e.g. for offline SPMD case).
            self.data_parallel_size = envs.VLLM_DP_SIZE
            self.data_parallel_rank = envs.VLLM_DP_RANK
            self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
            self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
            self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT

            if self.data_parallel_size > 1 and self.is_moe_model is False:
                raise ValueError(
                    "Offline data parallel mode is not supported/useful"
                    " for dense models."
                )

        self.data_parallel_index = self.data_parallel_rank

        if self.distributed_executor_backend == "external_launcher":
            os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
            logger.info("Disabling V1 multiprocessing for external launcher.")

        if self.distributed_executor_backend is None and self.world_size > 1:
            # We use multiprocessing by default if world_size fits on the
            # current node and we aren't in a ray placement group.
            from vllm.v1.executor import ray_utils

            backend: DistributedExecutorBackend = "mp"
            ray_found = ray_utils.ray_is_available()
            if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
                backend = "uni"
            elif current_platform.is_cuda() and self.nnodes > 1:
                backend = "mp"
            elif (
                current_platform.is_cuda()
                and cuda_device_count_stateless() < self.world_size
            ):
                gpu_count = cuda_device_count_stateless()
                raise ValueError(
                    f"World size ({self.world_size}) is larger than the number of "
                    f"available GPUs ({gpu_count}) in this node. If this is "
                    "intentional and you are using:\n"
                    "- ray, set '--distributed-executor-backend ray'.\n"
                    "- multiprocessing, set '--nnodes' appropriately."
                )
            elif self.data_parallel_backend == "ray":
                logger.info(
                    "Using ray distributed inference because "
                    "data_parallel_backend is ray"
                )
                backend = "ray"
            elif ray_found:
                if self.placement_group:
                    backend = "ray"
                else:
                    from ray import is_initialized as ray_is_initialized

                    if ray_is_initialized():
                        from ray.util import get_current_placement_group

                        if get_current_placement_group():
                            backend = "ray"
            self.distributed_executor_backend = backend
            logger.debug("Defaulting to use %s for distributed inference", backend)

        if self.distributed_executor_backend is None and self.world_size == 1:
            self.distributed_executor_backend = "uni"

        if self.max_parallel_loading_workers is not None:
            logger.warning(
                "max_parallel_loading_workers is currently "
                "not supported and will be ignored."
            )

        allowed_backends = ("mp", "uni", "external_launcher")
        if (
            self.distributed_executor_backend not in allowed_backends
            and self.nnodes > 1
        ):
            raise ValueError(
                "nnodes > 1 can only be set when distributed executor "
                "backend is mp, uni or external_launcher."
            )

    @property
    def use_ray(self) -> bool:
        """Whether workers run on Ray: backend "ray" or a custom executor
        class whose `uses_ray` attribute is truthy."""
        return self.distributed_executor_backend == "ray" or (
            isinstance(self.distributed_executor_backend, type)
            and getattr(self.distributed_executor_backend, "uses_ray", False)
        )

    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Validate the executor backend and adjust custom all-reduce usage.

        Raises:
            ValueError: for an unrecognized executor backend, or nsight
                profiling requested without Ray workers.
        """
        # Lazy import to avoid circular import
        from vllm.v1.executor import Executor

        # Enable batch invariance settings if requested
        if vllm_is_batch_invariant():
            self.disable_custom_all_reduce = True

        if (
            self.distributed_executor_backend is not None
            and not isinstance(self.distributed_executor_backend, str)
            and not (
                isinstance(self.distributed_executor_backend, type)
                and issubclass(self.distributed_executor_backend, Executor)
            )
        ):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{self.distributed_executor_backend}. Supported "
                "values are 'ray', 'mp', 'uni', 'external_launcher', "
                "a custom Executor subclass or its import path."
            )
        if self.use_ray:
            from vllm.v1.executor import ray_utils

            ray_utils.assert_ray_available()

        if not current_platform.use_custom_allreduce():
            self.disable_custom_all_reduce = True
            logger.debug(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on current platform."
            )
        if self.nnodes > 1:
            self.disable_custom_all_reduce = True
            logger.debug(
                "Disabled the custom all-reduce since we are running on multi-node."
            )
        if self.ray_workers_use_nsight and not self.use_ray:
            raise ValueError(
                "Unable to use nsight profiling unless workers run with Ray."
            )

        return self

    def replace(self, **kwargs) -> Self:
        """Return a copy of this config with the given fields replaced."""
        return replace(self, **kwargs)