"vscode:/vscode.git/clone" did not exist on "d43ad5a75790e4d97394940187bbf37402c4fa97"
ray_executor.py 25.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from collections import defaultdict
6
from collections.abc import Callable
7
from concurrent.futures import Future
8
from dataclasses import dataclass
9
from typing import TYPE_CHECKING, Any
10

11
import cloudpickle
12

13
import vllm.envs as envs
14
from vllm.logger import init_logger
15
from vllm.platforms import current_platform
16
from vllm.ray.ray_env import get_env_vars_to_copy
17
from vllm.utils.network_utils import (
18
19
20
21
    get_distributed_init_method,
    get_ip,
    get_open_port,
)
22
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
23
24
25
26
27
28
29
30
31
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (
    FutureWrapper,
    RayWorkerWrapper,
    initialize_ray_cluster,
    ray,
)
from vllm.v1.outputs import ModelRunnerOutput
32
33

if ray is None:
    # Ray is not installed: keep the name defined so annotations still work.
    ActorHandle = None
else:
    from ray.actor import ActorHandle
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)

# A single future already resolved to ``None``. Non-blocking calls that have
# no model output to report return this shared instance.
COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future()
COMPLETED_NONE_FUTURE.set_result(None)
@dataclass
class RayWorkerMetaData:
    """Bookkeeping record for a single Ray worker actor.

    Ray may create the worker actors in any order, so the rank each worker
    was created with is recorded here and a new rank is assigned once all
    workers exist and their node IPs are known.
    """

    # Handle to the Ray actor.
    worker: ActorHandle
    # Rank the actor was created with.
    created_rank: int
    # Rank assigned after re-sorting the workers; -1 until assigned.
    adjusted_rank: int = -1
    # IP address of the node hosting the actor; empty until queried.
    ip: str = ""
class RayDistributedExecutor(Executor):
    """Ray-based distributed executor"""

    # Env vars that differ per worker (host address/port, local rank, device
    # visibility). They are excluded when the driver's environment is copied
    # to the workers.
    WORKER_SPECIFIC_ENV_VARS = {
        "CUDA_VISIBLE_DEVICES",
        "HIP_VISIBLE_DEVICES",
        "ROCR_VISIBLE_DEVICES",
        "LOCAL_RANK",
        "VLLM_HOST_IP",
        "VLLM_HOST_PORT",
    }

    # Executor capability flags.
    uses_ray: bool = True
    supports_pp: bool = True
    def _init_executor(self) -> None:
        """Initialize the Ray cluster and create the worker actors.

        Also records flags consulted when executing steps: whether a KV
        transfer connector is configured and whether the sampler is used.
        """
        self.forward_dag: ray.dag.CompiledDAG | None = None

        # TPU and XPU must not compile NVIDIA's NCCL; use shared-memory
        # channels for the compiled DAG instead.
        if current_platform.is_tpu() or current_platform.is_xpu():
            os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"

        assert self.uses_ray
        initialize_ray_cluster(self.parallel_config)
        pg = self.parallel_config.placement_group

        # Disable Ray usage stats collection unless explicitly enabled ("1").
        if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(pg)

        # With a KV transfer connector, outputs are gathered from all workers.
        self.has_connector = self.vllm_config.kv_transfer_config is not None

        # The sampler is not used by pooling models or by EC producers.
        if self.vllm_config.model_config.runner_type == "pooling":
            self.uses_sampler = False
        else:
            ec_config = self.vllm_config.ec_transfer_config
            self.uses_sampler = ec_config is None or not ec_config.is_ec_producer

        # Scheduler output stashed by execute_model() for sample_tokens().
        self.scheduler_output: SchedulerOutput | None = None
    @property
    def max_concurrent_batches(self) -> int:
        """Number of batches that may be in flight at the same time.

        With pipeline parallelism (PP size > 1) this is the PP size. Without
        PP, async scheduling allows two batches to overlap; otherwise the PP
        size is returned unchanged.
        """
        pp_size = self.parallel_config.pipeline_parallel_size
        if pp_size > 1 or not self.scheduler_config.async_scheduling:
            return pp_size
        return 2
    def shutdown(self) -> None:
        """Tear down the compiled DAG, if one was built, and kill the workers.

        ``forward_dag`` is reset to ``None`` at the end. A later call
        therefore only logs and does nothing else.
        """
        # Somehow logger can be None here (presumably during interpreter
        # teardown when reached via __del__).
        if logger:
            logger.info(
                "Shutting down Ray distributed executor. If you see error log "
                "from logging.cc regarding SIGTERM received, please ignore "
                "because this is the expected termination process in Ray."
            )

        dag = getattr(self, "forward_dag", None)
        if dag is None:
            return

        dag.teardown()
        import ray

        for actor in self.workers:
            ray.kill(actor)
        self.forward_dag = None
    def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]:
        """Add an Nsight profiling config to the workers' Ray runtime env.

        ``ray_remote_kwargs["runtime_env"]`` is created if missing, and its
        ``"nsight"`` entry is set. The dict is modified in place and also
        returned.
        """
        nsight_options = {
            "t": "cuda,cudnn,cublas",
            "o": "'worker_process_%p'",
            "cuda-graph-trace": "node",
        }
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env["nsight"] = nsight_options
        return ray_remote_kwargs
    def _update_noset_device_env_vars(self, ray_remote_kwargs):
        """Set every platform "noset device" env var to ``"1"`` for the workers.

        The values are written to ``ray_remote_kwargs["runtime_env"]["env_vars"]``.
        The nested dicts are created if needed. ``ray_remote_kwargs`` is
        modified in place and also returned.
        """
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        env_vars = runtime_env.setdefault("env_vars", {})
        for name in current_platform.ray_noset_device_env_vars:
            env_vars[name] = "1"
        return ray_remote_kwargs
    def _get_env_vars_to_be_updated(self):
        """Return the per-worker env var dicts to send to the workers.

        Subclasses may override this to return different env vars. The
        default is the list built in ``_init_workers_ray``, with one dict per
        worker.
        """
        env_vars_per_worker = self._env_vars_for_all_workers
        return env_vars_per_worker
    def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs):
        """Create, rank, configure and initialize the Ray worker actors.

        One ``RayWorkerWrapper`` actor is spawned per selected placement-group
        bundle. The actors are then re-ranked so that workers on the driver's
        node come first and workers sharing a node are contiguous. Each worker
        receives the device-visibility and copied driver env vars, and is
        initialized via ``init_worker``, ``init_device`` and ``load_model``.
        Finally, the workers are arranged into ``self.pp_tp_workers`` for
        building the compiled DAG.

        Args:
            placement_group: Placement group whose bundles host the workers.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote`` when
                creating each worker actor (e.g. ``runtime_env``).

        Raises:
            RuntimeError: If the number of distinct worker node IDs differs
                from the number of distinct IP addresses (driver included).
        """
        num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: RayWorkerWrapper | None = None
        # The remaining workers are the actual ray actors.
        self.workers: list[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: list[list[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs
            )

        # The way ray actors are setup in vllm is that the visible devices are
        # not set by actors, they are left unset by ray. Internally we index
        # the right gpu with local_rank. This is similar to how mp mode works.
        self._update_noset_device_env_vars(ray_remote_kwargs)

        # Create the workers.
        bundle_indices: list[int]
        if envs.VLLM_RAY_BUNDLE_INDICES:
            # Use the bundle indices specified by the user.
            bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
            assert len(bundle_indices) == self.parallel_config.world_size, (
                "VLLM_RAY_BUNDLE_INDICES must have the same size"
                f" as the world size, but got {bundle_indices=} "
                f"and {self.parallel_config.world_size=}"
            )
            assert len(set(bundle_indices)) == len(bundle_indices), (
                "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
                f" but got {bundle_indices=}"
            )
        else:
            # Use the first N bundles that have device resources.
            bundle_indices = [
                bundle_id
                for bundle_id, bundle in enumerate(placement_group.bundle_specs)
                if bundle.get(current_platform.ray_device_key, 0)
            ][: self.parallel_config.world_size]

        worker_metadata: list[RayWorkerMetaData] = []
        driver_ip = get_ip()
        for rank, bundle_id in enumerate(bundle_indices):
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            if current_platform.ray_device_key == "GPU":
                # NV+AMD GPUs, and Intel XPUs
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=num_gpus,
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(rpc_rank=rank)
            else:
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=0,
                    resources={current_platform.ray_device_key: num_gpus},
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(rpc_rank=rank)

            worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))

        worker_ips = ray.get(
            [
                each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
                for each in worker_metadata
            ]
        )

        for each, ip in zip(worker_metadata, worker_ips):
            each.ip = ip

        logger.debug("workers: %s", worker_metadata)
        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)

        # Number of workers hosted on each node IP.
        ip_counts: dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the worker is on a node with smaller IP address, it
                should be placed first.
            """
            ip = item.ip
            return 0 if ip == driver_ip else 1, ip_counts[ip], ip

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        sorted_worker_metadata = sorted(
            worker_metadata, key=sort_by_driver_then_worker_ip
        )
        for i, item in enumerate(sorted_worker_metadata):
            item.adjusted_rank = i
        self.workers = [item.worker for item in sorted_worker_metadata]
        rerank_mapping = {
            item.created_rank: item.adjusted_rank for item in sorted_worker_metadata
        }
        self.collective_rpc("adjust_rank", args=(rerank_mapping,))

        # Get the (node id, device ids) of every worker. Submit all requests
        # first and wait once, instead of a blocking round trip per worker;
        # ray.get returns the results in submission order.
        worker_node_and_gpu_ids = ray.get(
            [
                worker.get_node_and_gpu_ids.remote()  # type: ignore[attr-defined]
                for worker in [self.driver_dummy_worker] + self.workers
                # driver_dummy_worker can be None when using ray spmd worker.
                if worker is not None
            ]
        )

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            node_gpus[node_id].extend(int(x) for x in gpu_ids)
        for gpu_id_list in node_gpus.values():
            gpu_id_list.sort()

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP`"
                " environment variable, make sure it is unique for"
                " each node."
            )

        # Set environment variables for the driver and workers.
        # We set CUDA_VISIBLE_DEVICES to ALL GPUs on the node for each worker.
        # This is needed because:
        # 1. Ray's compiled DAG needs to find the allocated GPU in
        #    CUDA_VISIBLE_DEVICES.
        # 2. vLLM's communication layer (NCCL, CustomAllreduce) needs to see
        #    all GPUs for P2P checks and communication setup. Though if it was
        #    just this reason, we could have also just kept the visible devices
        #    unset.
        # Each worker will use local_rank to index into the visible devices.
        all_args_to_update_environment_variables = [
            {
                current_platform.device_control_env_var: ",".join(
                    map(str, node_gpus[node_id])
                ),
            }
            for (node_id, _) in worker_node_and_gpu_ids
        ]

        # Environment variables to copy from driver to workers
        env_vars_to_copy = get_env_vars_to_copy(
            exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
            additional_vars=set(current_platform.additional_env_vars),
            destination="workers",
        )

        # Copy existing env vars to each worker's args
        for args in all_args_to_update_environment_variables:
            # TODO: refactor platform-specific env vars
            for name in env_vars_to_copy:
                if name in os.environ:
                    args[name] = os.environ[name]

        self._env_vars_for_all_workers = all_args_to_update_environment_variables

        self.collective_rpc(
            "update_environment_variables", args=(self._get_env_vars_to_be_updated(),)
        )

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port()
        )

        # Initialize the actual workers inside worker wrapper.
        all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            # Position of this worker among the workers on its node.
            local_rank = node_workers[node_id].index(rank)
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
            all_kwargs.append(kwargs)
        self.collective_rpc("init_worker", args=(all_kwargs,))

        self.collective_rpc("init_device")
        self.collective_rpc("load_model")

        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
            self.pp_tp_workers.append([])
            for tp_rank in range(self.parallel_config.tensor_parallel_size):
                # PP=2, TP=4
                # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                rank = (pp_rank * self.parallel_config.tensor_parallel_size) + tp_rank
                assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                assert pp_rank < len(self.pp_tp_workers)
                self.pp_tp_workers[pp_rank].append(self.workers[rank])
    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Forward a distributed reconfiguration request to every worker.

        If the request asks to shut down the current rank, the executor is
        shut down afterwards.
        """
        self.collective_rpc("reinitialize_distributed", args=(reconfig_request,))
        new_rank = reconfig_request.new_data_parallel_rank
        if new_rank == ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
            self.shutdown()
    def execute_model(  # type: ignore[override]
        self,
        scheduler_output: SchedulerOutput,
        non_block: bool = False,
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Start a model step, or defer it to ``sample_tokens()``.

        If the sampler is not used or no tokens are scheduled, the step is
        executed immediately and its result (or a future for it) is returned.
        Otherwise the scheduler output is stored and ``None`` (or a future
        already resolved to ``None`` when ``non_block``) is returned. The
        step then runs on the next ``sample_tokens()`` call.

        Raises:
            RuntimeError: If a previously stored scheduler output has not yet
                been consumed by ``sample_tokens()``.
        """
        if self.scheduler_output is not None:
            raise RuntimeError(
                "State error: sample_tokens() must be called "
                "after execute_model() returns None."
            )

        model_will_run = (
            self.uses_sampler and scheduler_output.total_num_scheduled_tokens
        )
        if not model_will_run:
            # Nothing to sample: call the model runner right away.
            return self._execute_dag(scheduler_output, None, non_block)

        # Stash the output; sample_tokens() will drive the DAG.
        self.scheduler_output = scheduler_output
        if non_block:
            return COMPLETED_NONE_FUTURE
        return None
    def sample_tokens(  # type: ignore[override]
        self,
        grammar_output: "GrammarOutput | None",
        non_block: bool = False,
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Run the model step that the previous ``execute_model()`` deferred.

        Args:
            grammar_output: The structured outputs grammar bitmask, if applicable.
            non_block: If True, a Future is returned instead of the output.

        Returns:
            The model runner output (or a Future for it). If no scheduler
            output is pending, ``None`` (or a future already resolved to
            ``None`` when ``non_block``).
        """
        pending, self.scheduler_output = self.scheduler_output, None
        if pending is None:
            return COMPLETED_NONE_FUTURE if non_block else None
        return self._execute_dag(pending, grammar_output, non_block)
    def _execute_dag(
        self,
        scheduler_output: SchedulerOutput,
        grammar_output: "GrammarOutput | None",
        non_block: bool = False,
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Execute one step through the compiled Ray DAG.

        The DAG is compiled lazily on first use. Without a KV connector, only
        the first output ref is used. With one, the outputs of all refs are
        combined by ``kv_output_aggregator``. When ``non_block`` is set, a
        ``FutureWrapper`` is returned instead of waiting for the result.
        """
        # Build the compiled DAG for the first time.
        if self.forward_dag is None:  # type: ignore
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        refs = self.forward_dag.execute((scheduler_output, grammar_output))  # type: ignore

        if self.has_connector:
            # Outputs from all workers must be aggregated.
            assert self.kv_output_aggregator is not None
            if non_block:
                return FutureWrapper(refs, self.kv_output_aggregator)
            return self.kv_output_aggregator.aggregate(ray.get(refs))

        # Without a connector, the output of the first ref is the result.
        if non_block:
            # Return immediately so the caller can move on to the next batch.
            return FutureWrapper(refs[0])
        return refs[0].get()
    def collective_rpc(  # type: ignore[override]
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
        non_block: bool = False,
    ) -> list[Any] | Future[list[Any]]:
        """Call ``method`` on every worker and collect the results.

        A callable ``method`` is serialized with cloudpickle before it is
        sent; a string is sent as-is. The results come back in
        ``self.workers`` order. If ``non_block`` is set, a ``FutureWrapper``
        over the refs is returned instead; otherwise the call waits up to
        ``timeout`` seconds.
        """
        if isinstance(method, str):
            sent_method = method
        else:
            sent_method = cloudpickle.dumps(method)
        del method

        call_kwargs = kwargs if kwargs is not None else {}
        refs = []
        for worker in self.workers:
            refs.append(
                worker.execute_method.remote(  # type: ignore[attr-defined]
                    sent_method, *args, **call_kwargs
                )
            )

        if non_block:
            return FutureWrapper(refs)
        return ray.get(refs, timeout=timeout)
    def _check_ray_cgraph_installation(self):
        """Check that Ray Compiled Graph and its dependencies are usable.

        Raises:
            ValueError: If Ray is older than 2.43.0, or the Ray Compiled Graph
                module cannot be found, or cupy is missing while the
                compiled-DAG channel type is "nccl".
        """
        import importlib.metadata
        import importlib.util

        from packaging import version

        minimum = version.parse("2.43.0")
        installed = version.parse(importlib.metadata.version("ray"))
        if installed < minimum:
            raise ValueError(
                f"Ray version {minimum} is required, but found {installed}"
            )

        if importlib.util.find_spec("ray.experimental.compiled_dag_ref") is None:
            raise ValueError(
                "Ray Compiled Graph is not installed. "
                "Run `pip install ray[cgraph]` to install it."
            )

        cupy_missing = importlib.util.find_spec("cupy") is None
        if cupy_missing and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl":
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
                "Run `pip install ray[cgraph]` and check cupy installation."
            )
    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the Ray Compiled Graph used for one model step.

        Every worker of the first PP stage receives the DAG input. TP rank
        ``i`` of each later stage consumes the output of TP rank ``i`` of the
        previous stage. The outputs of the last stage are gathered in a
        ``MultiOutputNode``.

        Args:
            enable_asyncio: Forwarded to ``experimental_compile``.

        Returns:
            The result of ``experimental_compile`` on the built graph.

        Raises:
            ValueError: If a Ray Compiled Graph prerequisite is missing (see
                ``_check_ray_cgraph_installation``), or if
                ``VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE`` is not one of
                ``'auto'``, ``'nccl'`` or ``'shm'``.
        """
        assert self.parallel_config.use_ray
        self._check_ray_cgraph_installation()
        # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
        # (it is 10 seconds by default). This is a Ray environment variable to
        # control the timeout of getting result from a compiled graph execution,
        # i.e., the distributed execution that includes model forward runs and
        # intermediate tensor communications, in the case of vllm.
        # Note: we should set this env var before importing
        # ray.dag, otherwise it will not take effect.
        os.environ.setdefault("RAY_CGRAPH_get_timeout", "300")  # noqa: SIM112
        from ray.dag import InputNode, MultiOutputNode

        logger.info(
            "RAY_CGRAPH_get_timeout is set to %s",
            os.environ["RAY_CGRAPH_get_timeout"],  # noqa: SIM112
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE,
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )

        # Validate the channel type once and use the validated value for the
        # rest of the graph construction.
        channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
        if channel_type not in ("auto", "nccl", "shm"):
            raise ValueError(
                "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
                f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'."
            )

        last_pp_rank = len(self.pp_tp_workers) - 1
        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            #
            # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput   # noqa: E501

            # All workers in the first TP group take the DAG input (the tuple
            # passed to `forward_dag.execute()`).
            outputs = [input_data for _ in self.pp_tp_workers[0]]
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                outputs = [
                    worker.execute_model_ray.bind(outputs[i])  # type: ignore[attr-defined]
                    for i, worker in enumerate(tp_group)
                ]

                if pp_rank < last_pp_rank and channel_type != "shm":
                    # Specify how intermediate tensors should be passed
                    # between pp stages; no need to specify for the last
                    # pp stage or when using shared memory channels.
                    outputs = [
                        output.with_tensor_transport(transport=channel_type)
                        for output in outputs
                    ]

            forward_dag = MultiOutputNode(outputs)

        if envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
            from ray.experimental.channel.accelerator_context import (
                register_accelerator_context,
            )

            from vllm.distributed.device_communicators.ray_communicator import (
                RayPPCommunicator,
            )

            register_accelerator_context(
                torch_module_name="cuda", communicator_cls=RayPPCommunicator
            )
            logger.info(
                "Using RayPPCommunicator "
                "(which wraps vLLM _PP GroupCoordinator) "
                "for Ray Compiled Graph communication."
            )
        else:
            logger.info(
                "Using Ray's NCCL communicator for Ray Compiled Graph communication."
            )

        return forward_dag.experimental_compile(
            enable_asyncio=enable_asyncio,
            _overlap_gpu_communication=envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )
    def __del__(self):
        """On garbage collection, run ``shutdown()`` to release Ray resources."""
        self.shutdown()
    def check_health(self) -> None:
        """Health-check hook; currently a no-op.

        The Ray workers are assumed to be healthy.
        TODO: actually check the health of the Ray workers.
        """
        return None