ray_executor.py 24.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from collections import defaultdict
6
from collections.abc import Callable
7
from concurrent.futures import Future
8
from dataclasses import dataclass
9
from typing import TYPE_CHECKING, Any
10

11
import cloudpickle
12

13
import vllm.envs as envs
14
from vllm.logger import init_logger
15
from vllm.platforms import current_platform
16
from vllm.ray.ray_env import get_env_vars_to_copy
17
from vllm.utils.network_utils import (
18
19
20
21
    get_distributed_init_method,
    get_ip,
    get_open_port,
)
22
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
23
24
25
26
27
28
29
30
31
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (
    FutureWrapper,
    RayWorkerWrapper,
    initialize_ray_cluster,
    ray,
)
from vllm.v1.outputs import ModelRunnerOutput
32
33

if ray is not None:
    from ray.actor import ActorHandle
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
    # Ray is optional at import time: keep ``ActorHandle`` defined so the
    # ``RayWorkerMetaData`` annotation below still resolves.
    # NOTE: ``PlacementGroupSchedulingStrategy`` stays undefined in this case.
    ActorHandle = None

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)

# An already-resolved future holding ``None``. Non-blocking calls return it
# when there is no model output to report (e.g. execute_model() deferring
# the real work to sample_tokens()).
COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future()
COMPLETED_NONE_FUTURE.set_result(None)

47

48
49
50
51
52
53
54
@dataclass
class RayWorkerMetaData:
    """
    Metadata for a Ray worker.
    The order of ray worker creation can be random,
    and we need to reset the rank after creating all workers.
    """

    # Handle to the Ray actor that runs this worker. The annotation is quoted
    # so the dataclass does not evaluate ``ActorHandle`` at definition time.
    worker: "ActorHandle"
    # Rank assigned in creation order, before workers are re-ranked.
    created_rank: int
    # Rank after sorting the workers; -1 means "not assigned yet".
    adjusted_rank: int = -1
    # IP address of the node hosting the worker; empty until it is queried.
    ip: str = ""


62
class RayDistributedExecutor(Executor):
    """Ray-based distributed executor.

    Launches one ``RayWorkerWrapper`` actor per rank inside the configured
    Ray placement group. Model execution is driven through a Ray Compiled
    Graph (``self.forward_dag``) that chains the pipeline-parallel stages.
    """

    # These env vars are worker-specific, therefore are NOT copied
    # from the driver to the workers
    WORKER_SPECIFIC_ENV_VARS = {
        "VLLM_HOST_IP",
        "VLLM_HOST_PORT",
        "LOCAL_RANK",
        "CUDA_VISIBLE_DEVICES",
    }

    # These non-vLLM env vars are copied from the driver to workers
    ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}

    uses_ray: bool = True
    supports_pp: bool = True

    def _init_executor(self) -> None:
        """Connect to the Ray cluster, launch the workers, and set up state."""
        # Built lazily on the first _execute_dag() call.
        self.forward_dag: ray.dag.CompiledDAG | None = None

        # For TPU or XPU, avoid compiling NVIDIA's NCCL
        if current_platform.is_tpu() or current_platform.is_xpu():
            os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"

        assert self.uses_ray
        initialize_ray_cluster(self.parallel_config)
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection unless explicitly enabled.
        # NOTE(review): this runs after initialize_ray_cluster(); confirm it
        # still takes effect for the already-initialized Ray session.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        # KV connector setup
        self.has_connector = self.vllm_config.kv_transfer_config is not None

        # Sampling is skipped for pooling models and for EC producers.
        self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
            self.vllm_config.ec_transfer_config is None
            or not self.vllm_config.ec_transfer_config.is_ec_producer
        )

        # Scheduler output stashed by execute_model() and consumed by
        # sample_tokens().
        self.scheduler_output: SchedulerOutput | None = None

    @property
    def max_concurrent_batches(self) -> int:
        """Number of batches that may be in flight at once.

        Ray distributed executor supports pipeline parallelism, meaning that
        it allows PP size batches to be executed concurrently. With async
        scheduling enabled this is fixed at 2.
        """
        if self.scheduler_config.async_scheduling:
            return 2
        return self.parallel_config.pipeline_parallel_size

    def shutdown(self) -> None:
        """Tear down the compiled DAG and kill the Ray worker actors.

        The teardown and kill only happen while ``self.forward_dag`` is set;
        it is reset to ``None`` afterwards, so repeated calls (e.g. from
        ``__del__``) only log.

        NOTE(review): workers are only killed if the DAG was ever compiled.
        """
        if logger:
            # Somehow logger can be None here.
            logger.info(
                "Shutting down Ray distributed executor. If you see error log "
                "from logging.cc regarding SIGTERM received, please ignore "
                "because this is the expected termination process in Ray."
            )
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            # Presumably re-imported locally because module globals may
            # already be cleared when this runs at interpreter shutdown.
            import ray

            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

    def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]:
        """Add an nsight profiling ``runtime_env`` to ``ray_remote_kwargs``.

        Mutates and returns the same dict.
        """
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update(
            {
                "nsight": {
                    "t": "cuda,cudnn,cublas",
                    "o": "'worker_process_%p'",
                    "cuda-graph-trace": "node",
                }
            }
        )

        return ray_remote_kwargs

    # child class could overwrite this to return actual env vars.
    def _get_env_vars_to_be_updated(self):
        """Return the per-worker env-var dicts sent to the workers."""
        return self._env_vars_for_all_workers

    def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs):
        """Create one Ray actor per rank and initialize them as vLLM workers.

        Steps:
        1. Choose placement-group bundles: from ``VLLM_RAY_BUNDLE_INDICES``,
           or the first ``world_size`` bundles that carry the platform device.
        2. Launch a ``RayWorkerWrapper`` actor on each bundle.
        3. Re-rank the actors so workers on the driver node come first and
           workers on the same node are contiguous.
        4. Push device-visibility and copied env vars to every worker.
        5. Run ``init_worker`` / ``init_device`` / ``load_model`` and group
           the workers into ``self.pp_tp_workers``.

        Args:
            placement_group: Ray placement group to schedule the actors on.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote``.

        Raises:
            RuntimeError: if the number of distinct nodes does not match the
                number of distinct IP addresses.
        """
        num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: RayWorkerWrapper | None = None
        # The remaining workers are the actual ray actors.
        self.workers: list[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: list[list[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs
            )

        # Create the workers.
        bundle_indices: list[int]
        if envs.VLLM_RAY_BUNDLE_INDICES:
            # Use the bundle indices specified by the user.
            bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
            assert len(bundle_indices) == self.parallel_config.world_size, (
                "VLLM_RAY_BUNDLE_INDICES must have the same size"
                f" as the world size, but got {bundle_indices=} "
                f"and {self.parallel_config.world_size=}"
            )
            assert len(set(bundle_indices)) == len(bundle_indices), (
                "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
                f" but got {bundle_indices=}"
            )
        else:
            # use the first N bundles that have GPU resources.
            bundle_indices = []
            for bundle_id, bundle in enumerate(placement_group.bundle_specs):
                if bundle.get(current_platform.ray_device_key, 0):
                    bundle_indices.append(bundle_id)
            bundle_indices = bundle_indices[: self.parallel_config.world_size]

        worker_metadata: list[RayWorkerMetaData] = []
        driver_ip = get_ip()
        for rank, bundle_id in enumerate(bundle_indices):
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            if current_platform.ray_device_key == "GPU":
                # NV+AMD GPUs, and Intel XPUs
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=num_gpus,
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(  # type: ignore[attr-defined]
                    vllm_config=self.vllm_config, rpc_rank=rank
                )
            else:
                # Other accelerators are requested as a custom Ray resource.
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=0,
                    resources={current_platform.ray_device_key: num_gpus},
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(  # type: ignore[attr-defined]
                    vllm_config=self.vllm_config, rpc_rank=rank
                )
            worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))

        worker_ips = ray.get(
            [
                each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
                for each in worker_metadata
            ]
        )

        for each, ip in zip(worker_metadata, worker_ips):
            each.ip = ip

        logger.debug("workers: %s", worker_metadata)
        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)

        # Number of workers hosted on each node IP.
        ip_counts: dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the worker is on a node with smaller IP address, it
                should be placed first.
            """
            ip = item.ip
            return 0 if ip == driver_ip else 1, ip_counts[ip], ip

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        sorted_worker_metadata = sorted(
            worker_metadata, key=sort_by_driver_then_worker_ip
        )
        for i, item in enumerate(sorted_worker_metadata):
            item.adjusted_rank = i
        self.workers = [item.worker for item in sorted_worker_metadata]
        rerank_mapping = {
            item.created_rank: item.adjusted_rank for item in sorted_worker_metadata
        }
        self.collective_rpc("adjust_rank", args=(rerank_mapping,))

        # Get the set of GPU IDs used on each node. Issue every remote call
        # first and fetch all results with one ray.get, so the queries run
        # concurrently instead of one blocking round-trip per worker.
        # driver_dummy_worker can be None when using ray spmd worker.
        candidate_workers = [self.driver_dummy_worker] + self.workers
        worker_node_and_gpu_ids = ray.get(
            [
                worker.get_node_and_gpu_ids.remote()  # type: ignore[attr-defined]
                for worker in candidate_workers
                if worker is not None
            ]
        )

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP`"
                " environment variable, make sure it is unique for"
                " each node."
            )

        # Set environment variables for the driver and workers.
        # Each worker sees every GPU used on its node via the platform's
        # device-control env var.
        all_args_to_update_environment_variables = [
            {
                current_platform.device_control_env_var: ",".join(
                    map(str, node_gpus[node_id])
                ),
            }
            for (node_id, _) in worker_node_and_gpu_ids
        ]

        # Environment variables to copy from driver to workers
        env_vars_to_copy = get_env_vars_to_copy(
            exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
            additional_vars=set(current_platform.additional_env_vars).union(
                self.ADDITIONAL_ENV_VARS
            ),
            destination="workers",
        )

        # Copy existing env vars to each worker's args
        for args in all_args_to_update_environment_variables:
            # TODO: refactor platform-specific env vars
            for name in env_vars_to_copy:
                if name in os.environ:
                    args[name] = os.environ[name]

        self._env_vars_for_all_workers = all_args_to_update_environment_variables

        self.collective_rpc(
            "update_environment_variables", args=(self._get_env_vars_to_be_updated(),)
        )

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port()
        )

        # Initialize the actual workers inside worker wrapper.
        all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            # Local rank = position of this worker among those on its node.
            local_rank = node_workers[node_id].index(rank)
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
            all_kwargs.append(kwargs)
        self.collective_rpc("init_worker", args=(all_kwargs,))

        self.collective_rpc("init_device")
        self.collective_rpc("load_model")

        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
            self.pp_tp_workers.append([])
            for tp_rank in range(self.parallel_config.tensor_parallel_size):
                # PP=2, TP=4
                # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                rank = (pp_rank * self.parallel_config.tensor_parallel_size) + tp_rank
                assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                assert pp_rank < len(self.pp_tp_workers)
                self.pp_tp_workers[pp_rank].append(self.workers[rank])

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Forward a distributed reconfiguration to all workers.

        Shuts this executor down if the request asks the current rank to
        shut down.
        """
        self.collective_rpc("reinitialize_distributed", args=(reconfig_request,))
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()

    def execute_model(  # type: ignore[override]
        self,
        scheduler_output: SchedulerOutput,
        non_block: bool = False,
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Run or defer model execution for one scheduler step.

        If no sampling will happen (pooling runner, EC producer, or no
        scheduled tokens), the compiled DAG runs immediately without grammar
        output and its result is returned. Otherwise the scheduler output is
        stashed for sample_tokens(), and ``None`` (or a completed ``None``
        future when ``non_block``) is returned.

        Raises:
            RuntimeError: if a previously stashed scheduler output has not
                yet been consumed by sample_tokens().
        """
        if self.scheduler_output is not None:
            raise RuntimeError(
                "State error: sample_tokens() must be called "
                "after execute_model() returns None."
            )

        if not self.uses_sampler or not scheduler_output.total_num_scheduled_tokens:
            # Model will not execute, call model runner immediately.
            return self._execute_dag(scheduler_output, None, non_block)

        # Model will execute, defer to sample_tokens() call.
        self.scheduler_output = scheduler_output
        return COMPLETED_NONE_FUTURE if non_block else None

    def sample_tokens(  # type: ignore[override]
        self,
        grammar_output: "GrammarOutput | None",
        non_block: bool = False,
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Execute the model on the Ray workers.

        The scheduler output to use should have been provided in
        a prior call to execute_model().

        Args:
            grammar_output: The structured outputs grammar bitmask, if applicable.
            non_block: If True, the method will return a Future.

        Returns:
            The model runner output, or ``None`` (a completed ``None`` future
            when ``non_block``) if no scheduler output is pending.
        """
        scheduler_output = self.scheduler_output
        if scheduler_output is None:
            return COMPLETED_NONE_FUTURE if non_block else None  # noqa

        # Consume the stashed output so the next execute_model() may proceed.
        self.scheduler_output = None

        return self._execute_dag(scheduler_output, grammar_output, non_block)

    def _execute_dag(
        self,
        scheduler_output: SchedulerOutput,
        grammar_output: "GrammarOutput | None",
        non_block: bool = False,
    ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
        """Feed ``(scheduler_output, grammar_output)`` through the compiled DAG.

        Without a KV connector, only the first output (the first TP worker
        of the last PP stage) is returned. With a connector, the outputs of
        all last-stage workers are combined by ``kv_output_aggregator``.
        """
        # Build the compiled DAG for the first time.
        if self.forward_dag is None:  # type: ignore
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        refs = self.forward_dag.execute((scheduler_output, grammar_output))  # type: ignore

        if not self.has_connector:
            # Get output only from a single worker: refs[0] belongs to the
            # first TP worker of the last PP stage.
            # In blocking mode, wait here until the result is available.
            if not non_block:
                return refs[0].get()

            # In non-blocking mode (presumably used when several batches may
            # be in flight, e.g. with PP), return a FutureWrapper right away
            # so the scheduler can move on to the next batch.
            return FutureWrapper(refs[0])

        # Get output from all workers when connector is present
        assert self.kv_output_aggregator is not None
        if not non_block:
            # Block and get results from all workers
            return self.kv_output_aggregator.aggregate(ray.get(refs))

        # Return a future that will aggregate outputs from all workers
        return FutureWrapper(refs, self.kv_output_aggregator)

    def collective_rpc(  # type: ignore[override]
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
        non_block: bool = False,
    ) -> list[Any] | Future[list[Any]]:
        """Runs the given method on all workers.

        Args:
            method: Name of a worker method, or a callable that is
                cloudpickled before being sent to the workers.
            timeout: Passed to ``ray.get``; only used when blocking.
            args: Positional arguments forwarded to every worker.
            kwargs: Keyword arguments forwarded to every worker.
            non_block: If True, return a future instead of waiting.

        Returns:
            One result per worker, in ``self.workers`` order (or a future of
            that list when ``non_block``).
        """
        sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
        del method

        if kwargs is None:
            kwargs = {}
        ray_worker_outputs = [
            worker.execute_method.remote(  # type: ignore[attr-defined]
                sent_method, *args, **kwargs
            )
            for worker in self.workers
        ]

        # Get the results of the ray workers.
        if non_block:
            return FutureWrapper(ray_worker_outputs)

        return ray.get(ray_worker_outputs, timeout=timeout)

    def _check_ray_cgraph_installation(self):
        """Validate that Ray Compiled Graph and its dependencies are usable.

        Raises:
            ValueError: if Ray is older than 2.43.0, if the
                ``ray.experimental.compiled_dag_ref`` module cannot be found,
                or if cupy is missing while the channel type is ``nccl``.
        """
        import importlib.metadata
        import importlib.util

        from packaging import version

        required_version = version.parse("2.43.0")
        current_version = version.parse(importlib.metadata.version("ray"))
        if current_version < required_version:
            raise ValueError(
                f"Ray version {required_version} is "
                f"required, but found {current_version}"
            )

        cgraph_spec = importlib.util.find_spec("ray.experimental.compiled_dag_ref")
        if cgraph_spec is None:
            raise ValueError(
                "Ray Compiled Graph is not installed. "
                "Run `pip install ray[cgraph]` to install it."
            )

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl":
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
                "Run `pip install ray[cgraph]` and check cupy installation."
            )

    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the Ray DAG chaining the PP stages.

        Every worker of the first TP group receives the DAG input. Each
        subsequent PP stage's worker ``i`` consumes the output of the
        previous stage's worker ``i``. The outputs of the last stage form the
        DAG's outputs.

        Raises:
            ValueError: if ``VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE`` is not
                one of ``auto``, ``nccl`` or ``shm``, or if the Ray Compiled
                Graph installation check fails.
        """
        assert self.parallel_config.use_ray
        self._check_ray_cgraph_installation()
        # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
        # (it is 10 seconds by default). This is a Ray environment variable to
        # control the timeout of getting result from a compiled graph execution,
        # i.e., the distributed execution that includes model forward runs and
        # intermediate tensor communications, in the case of vllm.
        # Note: we should set this env var before importing
        # ray.dag, otherwise it will not take effect.
        os.environ.setdefault("RAY_CGRAPH_get_timeout", "300")  # noqa: SIM112
        from ray.dag import InputNode, MultiOutputNode

        logger.info(
            "RAY_CGRAPH_get_timeout is set to %s",
            os.environ["RAY_CGRAPH_get_timeout"],  # noqa: SIM112
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE,
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )

        channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
        if channel_type not in ("auto", "nccl", "shm"):
            raise ValueError(
                "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
                f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'."
            )

        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            #
            # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput   # noqa: E501

            # All workers in the first TP group take the DAG input, i.e. the
            # (SchedulerOutput, GrammarOutput) tuple passed to execute().
            outputs = [input_data for _ in self.pp_tp_workers[0]]
            last_pp_rank = len(self.pp_tp_workers) - 1
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                outputs = [
                    worker.execute_model_ray.bind(outputs[i])  # type: ignore[attr-defined]
                    for i, worker in enumerate(tp_group)
                ]

                if pp_rank < last_pp_rank and channel_type != "shm":
                    # Specify how intermediate tensors should be passed
                    # between pp stages; no need to specify for the last
                    # pp stage or when using the shared-memory channel.
                    outputs = [
                        output.with_tensor_transport(transport=channel_type)
                        for output in outputs
                    ]

            forward_dag = MultiOutputNode(outputs)

        if envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
            from ray.experimental.channel.accelerator_context import (
                register_accelerator_context,
            )

            from vllm.distributed.device_communicators.ray_communicator import (
                RayPPCommunicator,
            )

            register_accelerator_context(
                torch_module_name="cuda", communicator_cls=RayPPCommunicator
            )
            logger.info(
                "Using RayPPCommunicator "
                "(which wraps vLLM _PP GroupCoordinator) "
                "for Ray Compiled Graph communication."
            )
        else:
            logger.info(
                "Using Ray's NCCL communicator for Ray Compiled Graph communication."
            )

        return forward_dag.experimental_compile(
            enable_asyncio=enable_asyncio,
            _overlap_gpu_communication=envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )

    def __del__(self):
        """Best-effort cleanup when the executor is garbage collected."""
        self.shutdown()

    def check_health(self) -> None:
        """No-op health check (workers are assumed healthy)."""
        # Assume that the Ray workers are healthy.
        # TODO: check the health of the Ray workers
        return