ray_executor.py 24.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from collections import defaultdict
6
from collections.abc import Callable
7
from concurrent.futures import Future
8
from dataclasses import dataclass
9
from typing import TYPE_CHECKING, Any
10

11
import cloudpickle
12

13
import vllm.envs as envs
14
from vllm.logger import init_logger
15
from vllm.platforms import current_platform
16
from vllm.ray.ray_env import get_env_vars_to_copy
17
from vllm.utils.network_utils import (
18
19
20
21
    get_distributed_init_method,
    get_ip,
    get_open_port,
)
22
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
23
24
25
26
27
28
29
30
31
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (
    FutureWrapper,
    RayWorkerWrapper,
    initialize_ray_cluster,
    ray,
)
from vllm.v1.outputs import ModelRunnerOutput
32
33

if ray is None:
    # `ray` could not be provided by ray_utils (presumably because Ray is not
    # installed); keep the name used in annotations below defined.
    ActorHandle = None
else:
    from ray.actor import ActorHandle
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)

# A single, already-resolved future holding ``None``. It is handed out by the
# non-blocking code paths in this module that have no model output to return.
COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future()
COMPLETED_NONE_FUTURE.set_result(None)

47

48
49
50
51
52
53
54
@dataclass
class RayWorkerMetaData:
    """Bookkeeping record for one Ray worker.

    Ray workers may be created in a random order, so their ranks are reset
    once all workers exist. This record keeps both the rank a worker was
    created with and the rank it ends up with.
    """

    # Handle of the Ray actor that backs this worker.
    worker: ActorHandle
    # Rank the worker was given at creation time.
    created_rank: int
    # Rank after re-ordering; stays -1 until it has been assigned.
    adjusted_rank: int = -1
    # IP address of the worker's node; stays empty until it has been queried.
    ip: str = ""


62
class RayDistributedExecutor(Executor):
63
64
65
66
67
    """Ray-based distributed executor"""

    # Environment variables whose values are specific to each worker.
    # They are therefore NOT copied from the driver to the workers.
    WORKER_SPECIFIC_ENV_VARS = {
        "CUDA_VISIBLE_DEVICES",
        "LOCAL_RANK",
        "VLLM_HOST_IP",
        "VLLM_HOST_PORT",
    }

    # Non-vLLM environment variables that ARE copied from the driver to
    # the workers.
    ADDITIONAL_ENV_VARS = {
        "HF_TOKEN",
        "HUGGING_FACE_HUB_TOKEN",
    }

    # Flags describing this executor: it runs on Ray and supports
    # pipeline parallelism (PP).
    uses_ray: bool = True
    supports_pp: bool = True
79

80
    def _init_executor(self) -> None:
        """Initialize the Ray cluster, the workers and the executor state.

        Side effects visible here: may set
        ``VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE`` and
        ``RAY_USAGE_STATS_ENABLED`` in ``os.environ``, initializes the Ray
        cluster, creates the workers, and sets ``forward_dag``,
        ``has_connector``, ``uses_sampler`` and ``scheduler_output``.
        """
        # The compiled DAG is built on first use (see _execute_dag).
        self.forward_dag: ray.dag.CompiledDAG | None = None

        # For TPU or XPU, avoid compiling NVIDIA's NCCL
        on_tpu_or_xpu = current_platform.is_tpu() or current_platform.is_xpu()
        if on_tpu_or_xpu:
            os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"

        assert self.uses_ray
        initialize_ray_cluster(self.parallel_config)
        # Read the placement group only after the cluster was initialized.
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection unless it was explicitly
        # enabled with RAY_USAGE_STATS_ENABLED=1.
        if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        # KV connector setup
        self.has_connector = self.vllm_config.kv_transfer_config is not None

        # The sampler is used unless this is a pooling runner or the
        # instance is configured as an EC producer.
        if self.vllm_config.model_config.runner_type == "pooling":
            self.uses_sampler = False
        else:
            self.uses_sampler = (
                self.vllm_config.ec_transfer_config is None
                or not self.vllm_config.ec_transfer_config.is_ec_producer
            )

        # Scheduler output parked by execute_model() for sample_tokens().
        self.scheduler_output: SchedulerOutput | None = None

109
110
111
112
113
    @property
    def max_concurrent_batches(self) -> int:
        """Ray distributed executor supports pipeline parallelism,
        meaning that it allows PP size batches to be executed concurrently.
        """
114
115
        pp_size = self.parallel_config.pipeline_parallel_size
        return 2 if pp_size <= 1 and self.scheduler_config.async_scheduling else pp_size
116

117
    def shutdown(self) -> None:
        """Tear down the compiled DAG and kill the Ray workers.

        Nothing is torn down when no compiled DAG exists (the attribute is
        missing or ``None``), so calling this repeatedly is harmless.
        """
        # Somehow logger can be None here.
        if logger:
            logger.info(
                "Shutting down Ray distributed executor. If you see error log "
                "from logging.cc regarding SIGTERM received, please ignore "
                "because this is the expected termination process in Ray."
            )

        # `forward_dag` may not exist if initialization never got that far.
        dag = getattr(self, "forward_dag", None)
        if dag is None:
            return

        dag.teardown()
        import ray

        for worker in self.workers:
            ray.kill(worker)
        self.forward_dag = None

133
    def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]:
134
135
136
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
137
138
139
140
141
142
143
        runtime_env.update(
            {
                "nsight": {
                    "t": "cuda,cudnn,cublas",
                    "o": "'worker_process_%p'",
                    "cuda-graph-trace": "node",
                }
144
            }
145
        )
146
147
148

        return ray_remote_kwargs

149
150
151
152
    # child class could overwrite this to return actual env vars.
    def _get_env_vars_to_be_updated(self):
        return self._env_vars_for_all_workers

153
    def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs):
        """Create, rank, configure and initialize the Ray worker actors.

        The method runs these steps in order:

        1. Choose one placement-group bundle per worker, either from
           ``VLLM_RAY_BUNDLE_INDICES`` or from the first bundles that
           provide the platform's device resource.
        2. Create one ``RayWorkerWrapper`` actor per chosen bundle.
        3. Re-rank the workers (driver node first, then nodes with fewer
           workers, then by IP) and send the mapping to every worker.
        4. Collect each worker's node id and GPU ids and check that the
           number of nodes matches the number of distinct IP addresses.
        5. Push the per-worker environment variables.
        6. Run ``init_worker``, ``init_device`` and ``load_model`` on all
           workers and fill ``self.pp_tp_workers``.

        Args:
            placement_group: Placement group whose bundles host the workers.
            **ray_remote_kwargs: Extra keyword arguments passed to
                ``ray.remote`` for every worker actor.

        Raises:
            AssertionError: If ``VLLM_RAY_BUNDLE_INDICES`` does not contain
                exactly ``world_size`` distinct indices.
            RuntimeError: If the number of nodes differs from the number of
                unique IP addresses.
        """
        num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: RayWorkerWrapper | None = None
        # The remaining workers are the actual ray actors.
        self.workers: list[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: list[list[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs
            )

        # Create the workers.
        bundle_indices: list[int]
        if envs.VLLM_RAY_BUNDLE_INDICES:
            # Use the bundle indices specified by the user.
            bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
            assert len(bundle_indices) == self.parallel_config.world_size, (
                "VLLM_RAY_BUNDLE_INDICES must have the same size"
                f" as the world size, but got {bundle_indices=} "
                f"and {self.parallel_config.world_size=}"
            )
            assert len(set(bundle_indices)) == len(bundle_indices), (
                "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
                f" but got {bundle_indices=}"
            )
        else:
            # use the first N bundles that have GPU resources.
            bundle_indices = [
                bundle_id
                for bundle_id, bundle in enumerate(placement_group.bundle_specs)
                if bundle.get(current_platform.ray_device_key, 0)
            ]
            bundle_indices = bundle_indices[: self.parallel_config.world_size]

        worker_metadata: list[RayWorkerMetaData] = []
        driver_ip = get_ip()
        device_key = current_platform.ray_device_key
        for rank, bundle_id in enumerate(bundle_indices):
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            # Only the resource request differs between platforms; the
            # dicts are rebuilt per worker so nothing is shared across calls.
            if device_key == "GPU":
                # NV+AMD GPUs, and Intel XPUs
                resource_kwargs = {"num_gpus": num_gpus}
            else:
                resource_kwargs = {
                    "num_gpus": 0,
                    "resources": {device_key: num_gpus},
                }

            worker = ray.remote(
                num_cpus=0,
                scheduling_strategy=scheduling_strategy,
                **resource_kwargs,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(rpc_rank=rank)
            worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))

        worker_ips = ray.get(
            [
                each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
                for each in worker_metadata
            ]
        )

        for each, ip in zip(worker_metadata, worker_ips):
            each.ip = ip

        logger.debug("workers: %s", worker_metadata)
        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)

        # Number of workers hosted on each IP address.
        ip_counts: dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the worker is on a node with smaller IP address, it
                should be placed first.
            """
            ip = item.ip
            return 0 if ip == driver_ip else 1, ip_counts[ip], ip

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        sorted_worker_metadata = sorted(
            worker_metadata, key=sort_by_driver_then_worker_ip
        )
        for i, item in enumerate(sorted_worker_metadata):
            item.adjusted_rank = i
        self.workers = [item.worker for item in sorted_worker_metadata]
        rerank_mapping = {
            item.created_rank: item.adjusted_rank for item in sorted_worker_metadata
        }
        self.collective_rpc("adjust_rank", args=(rerank_mapping,))

        # Get the set of GPU IDs used on each node.
        # driver_dummy_worker can be None when using ray spmd worker.
        queried_workers = [
            worker
            for worker in [self.driver_dummy_worker] + self.workers
            if worker is not None
        ]
        # Submit every query first and wait once, instead of blocking on
        # each worker in turn; results come back in submission order.
        worker_node_and_gpu_ids = ray.get(
            [
                worker.get_node_and_gpu_ids.remote()  # type: ignore[attr-defined]
                for worker in queried_workers
            ]
        )

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP`"
                " environment variable, make sure it is unique for"
                " each node."
            )

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [
            {
                current_platform.device_control_env_var: ",".join(
                    map(str, node_gpus[node_id])
                ),
            }
            for (node_id, _) in worker_node_and_gpu_ids
        ]

        # Environment variables to copy from driver to workers
        env_vars_to_copy = get_env_vars_to_copy(
            exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
            additional_vars=set(current_platform.additional_env_vars).union(
                self.ADDITIONAL_ENV_VARS
            ),
            destination="workers",
        )

        # Copy existing env vars to each worker's args. The driver-side
        # values are the same for every worker, so look them up only once.
        # TODO: refactor platform-specific env vars
        copied_env = {
            name: os.environ[name] for name in env_vars_to_copy if name in os.environ
        }
        for args in all_args_to_update_environment_variables:
            args.update(copied_env)

        self._env_vars_for_all_workers = all_args_to_update_environment_variables

        self.collective_rpc(
            "update_environment_variables", args=(self._get_env_vars_to_be_updated(),)
        )

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port()
        )

        # Initialize the actual workers inside worker wrapper.
        all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            # Position of this rank among the workers of its node.
            local_rank = node_workers[node_id].index(rank)
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
            all_kwargs.append(kwargs)
        self.collective_rpc("init_worker", args=(all_kwargs,))

        self.collective_rpc("init_device")
        self.collective_rpc("load_model")

        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
            self.pp_tp_workers.append([])
            for tp_rank in range(self.parallel_config.tensor_parallel_size):
                # PP=2, TP=4
                # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                rank = (pp_rank * self.parallel_config.tensor_parallel_size) + tp_rank
                assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                assert pp_rank < len(self.pp_tp_workers)
                self.pp_tp_workers[pp_rank].append(self.workers[rank])

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Forward a reconfiguration request to all workers.

        After the workers have handled the request, the executor shuts
        itself down if the request's new data-parallel rank is
        ``ReconfigureRankType.SHUTDOWN_CURRENT_RANK``.

        Args:
            reconfig_request: The request passed on to every worker.
        """
        self.collective_rpc("reinitialize_distributed", args=(reconfig_request,))

        new_dp_rank = reconfig_request.new_data_parallel_rank
        if new_dp_rank == ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
            self.shutdown()

    def execute_model(  # type: ignore[override]
        self,
        scheduler_output: SchedulerOutput,
        non_block: bool = False,
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Run the DAG right away or park the scheduler output.

        If the sampler is not used or no tokens are scheduled, the DAG is
        executed immediately and its result is returned. Otherwise the
        scheduler output is stored for a following ``sample_tokens()`` call.

        Args:
            scheduler_output: The scheduler output for this step.
            non_block: If True, a Future is returned instead of a value.

        Returns:
            The result of the immediate DAG execution, or ``None`` (an
            already-completed ``None`` future when ``non_block`` is True)
            when execution is deferred to ``sample_tokens()``.

        Raises:
            RuntimeError: If a previously stored scheduler output has not
                been consumed by ``sample_tokens()`` yet.
        """
        if self.scheduler_output is not None:
            raise RuntimeError(
                "State error: sample_tokens() must be called "
                "after execute_model() returns None."
            )

        if self.uses_sampler and scheduler_output.total_num_scheduled_tokens:
            # Model will execute, defer to sample_tokens() call.
            self.scheduler_output = scheduler_output
            if non_block:
                return COMPLETED_NONE_FUTURE
            return None

        # Model will not execute, call model runner immediately.
        return self._execute_dag(scheduler_output, None, non_block)

    def sample_tokens(  # type: ignore[override]
        self,
        grammar_output: "GrammarOutput | None",
        non_block: bool = False,
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Execute the model on the Ray workers.

        The scheduler output to use should have been provided in
        a prior call to execute_model(). When none is stored, nothing is
        executed.

        Args:
            grammar_output: The structured outputs grammar bitmask, if applicable.
            non_block: If True, the method will return a Future.

        Returns:
            The model runner output, or ``None`` (an already-completed
            ``None`` future when ``non_block`` is True) if no scheduler
            output was stored.
        """
        pending_output = self.scheduler_output
        if pending_output is not None:
            # Consume the stored output so execute_model() can be called again.
            self.scheduler_output = None
            return self._execute_dag(pending_output, grammar_output, non_block)

        if non_block:
            return COMPLETED_NONE_FUTURE
        return None

    def _execute_dag(
        self,
        scheduler_output: SchedulerOutput,
        grammar_output: "GrammarOutput | None",
        non_block: bool = False,
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Execute one step on the compiled Ray DAG.

        The DAG is compiled on the first call. Without a KV connector only
        the first output reference is used; with a connector the outputs of
        all workers are combined by ``self.kv_output_aggregator``.

        Args:
            scheduler_output: The scheduler output fed to the DAG.
            grammar_output: The grammar output fed to the DAG, may be None.
            non_block: If True, return a ``FutureWrapper`` instead of
                blocking for the result.

        Returns:
            The (possibly aggregated) output, or a ``FutureWrapper`` for it
            when ``non_block`` is True.
        """
        # Build the compiled DAG for the first time.
        if self.forward_dag is None:  # type: ignore
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        refs = self.forward_dag.execute((scheduler_output, grammar_output))  # type: ignore

        if self.has_connector:
            # Get output from all workers when connector is present
            assert self.kv_output_aggregator is not None
            if non_block:
                # Return a future that will aggregate outputs from all workers
                return FutureWrapper(refs, self.kv_output_aggregator)
            # Block and get results from all workers
            return self.kv_output_aggregator.aggregate(ray.get(refs))

        # Get output only from a single worker (output_rank)
        if non_block:
            # Return a FutureWrapper immediately so that the scheduler can
            # yield to the next batch.
            return FutureWrapper(refs[0])
        # Block here until the result is available.
        return refs[0].get()
463

464
    def collective_rpc(  # type: ignore[override]
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
        non_block: bool = False,
    ) -> list[Any] | Future[list[Any]]:
        """Runs the given method on all workers.

        Args:
            method: Name of the worker method, or a callable, which is
                serialized with ``cloudpickle`` before being sent.
            timeout: Passed to ``ray.get`` in the blocking case; it is not
                used when ``non_block`` is True.
            args: Positional arguments for the method.
            kwargs: Keyword arguments for the method.
            non_block: If True, return a ``FutureWrapper`` over the pending
                results instead of waiting for them.

        Returns:
            The list of results, one per worker in ``self.workers``, or a
            ``FutureWrapper`` for them when ``non_block`` is True.
        """
        if isinstance(method, str):
            sent_method = method
        else:
            sent_method = cloudpickle.dumps(method)
        del method

        call_kwargs = {} if kwargs is None else kwargs
        pending_results = []
        for worker in self.workers:
            pending_results.append(
                worker.execute_method.remote(  # type: ignore[attr-defined]
                    sent_method, *args, **call_kwargs
                )
            )

        if non_block:
            return FutureWrapper(pending_results)

        # Get the results of the ray workers.
        return ray.get(pending_results, timeout=timeout)
490

491
    def _check_ray_cgraph_installation(self):
        """Verify that Ray Compiled Graph can be used.

        Checks, in order: the installed Ray version is at least 2.43.0, the
        ``ray.experimental.compiled_dag_ref`` module can be found, and
        ``cupy`` can be found when the channel type is ``"nccl"``.

        Raises:
            ValueError: If any of these checks fails.
        """
        import importlib.metadata
        import importlib.util

        from packaging import version

        minimum_ray = version.parse("2.43.0")
        installed_ray = version.parse(importlib.metadata.version("ray"))
        if installed_ray < minimum_ray:
            raise ValueError(
                f"Ray version {minimum_ray} is "
                f"required, but found {installed_ray}"
            )

        if importlib.util.find_spec("ray.experimental.compiled_dag_ref") is None:
            raise ValueError(
                "Ray Compiled Graph is not installed. "
                "Run `pip install ray[cgraph]` to install it."
            )

        cupy_missing = importlib.util.find_spec("cupy") is None
        if cupy_missing and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl":
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
                "Run `pip install ray[cgraph]` and check cupy installation."
            )
520
521

    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the Ray DAG that runs one model step.

        The DAG chains the TP groups in ``self.pp_tp_workers`` stage by
        stage: every worker of the first stage receives the DAG input, and
        every later worker receives the output of the worker with the same
        TP index in the previous stage.

        Args:
            enable_asyncio: Passed through to ``experimental_compile``.

        Returns:
            The compiled DAG returned by ``experimental_compile``.

        Raises:
            ValueError: If ``VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE`` is not
                one of ``'auto'``, ``'nccl'`` or ``'shm'``, or if the
                installation check fails.
        """
        assert self.parallel_config.use_ray
        self._check_ray_cgraph_installation()
        # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
        # (it is 10 seconds by default). This is a Ray environment variable to
        # control the timeout of getting result from a compiled graph execution,
        # i.e., the distributed execution that includes model forward runs and
        # intermediate tensor communications, in the case of vllm.
        # Note: we should set this env var before importing
        # ray.dag, otherwise it will not take effect.
        os.environ.setdefault("RAY_CGRAPH_get_timeout", "300")  # noqa: SIM112
        from ray.dag import InputNode, MultiOutputNode

        logger.info(
            "RAY_CGRAPH_get_timeout is set to %s",
            os.environ["RAY_CGRAPH_get_timeout"],  # noqa: SIM112
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE,
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )

        channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
        if channel_type not in ("auto", "nccl", "shm"):
            raise ValueError(
                "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
                f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'."
            )

        # Index of the final PP stage; it does not change while building.
        last_pp_rank = len(self.pp_tp_workers) - 1

        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            #
            # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput   # noqa: E501

            # All workers in the first TP group take the DAG input.
            outputs = [input_data] * len(self.pp_tp_workers[0])
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                stage_outputs = []
                for tp_index, worker in enumerate(tp_group):
                    stage_outputs.append(
                        worker.execute_model_ray.bind(outputs[tp_index])  # type: ignore[attr-defined]
                    )
                outputs = stage_outputs

                is_intermediate_stage = pp_rank < last_pp_rank
                if (
                    is_intermediate_stage
                    and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm"
                ):
                    # Specify how intermediate tensors should be passed
                    # between pp stages, no need to specify for the last
                    # pp stage or when using shared memory (the default).
                    transport = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
                    outputs = [
                        node.with_tensor_transport(transport=transport)
                        for node in outputs
                    ]

            forward_dag = MultiOutputNode(outputs)

        if not envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
            logger.info(
                "Using Ray's NCCL communicator for Ray Compiled Graph communication."
            )
        else:
            from ray.experimental.channel.accelerator_context import (
                register_accelerator_context,
            )

            from vllm.distributed.device_communicators.ray_communicator import (
                RayPPCommunicator,
            )

            register_accelerator_context(
                torch_module_name="cuda", communicator_cls=RayPPCommunicator
            )
            logger.info(
                "Using RayPPCommunicator "
                "(which wraps vLLM _PP GroupCoordinator) "
                "for Ray Compiled Graph communication."
            )

        return forward_dag.experimental_compile(
            enable_asyncio=enable_asyncio,
            _overlap_gpu_communication=envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )
615
616

    def __del__(self):
617
        self.shutdown()
618

619
620
621
622
    def check_health(self) -> None:
        # Assume that the Ray workers are healthy.
        # TODO: check the health of the Ray workers
        return