# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from collections import defaultdict
from collections.abc import Callable
from concurrent.futures import Future
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import cloudpickle

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.utils.network_utils import (
    get_distributed_init_method,
    get_ip,
    get_open_port,
)
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (
    WORKER_SPECIFIC_ENV_VARS,
    FutureWrapper,
    RayWorkerWrapper,
    initialize_ray_cluster,
    ray,
)
from vllm.v1.outputs import ModelRunnerOutput

# `ray` is None when Ray is not available; keep the module importable then.
if ray is not None:
    from ray.actor import ActorHandle
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
    ActorHandle = None

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)

# Shared, already-resolved future holding None; returned by the non-blocking
# code paths that have no model output to deliver.
COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future()
COMPLETED_NONE_FUTURE.set_result(None)


@dataclass
class RayWorkerMetaData:
    """
    Metadata for a Ray worker.

    The order of ray worker creation can be random,
    and we need to reset the rank after creating all workers.
    """

    # Handle of the Ray actor backing this worker.
    worker: ActorHandle
    # Rank the worker was given when its actor was created.
    created_rank: int
    # Rank after the workers have been re-sorted; -1 until assigned.
    adjusted_rank: int = -1
    # IP address reported by the worker's node; empty until queried.
    ip: str = ""


class RayDistributedExecutor(Executor):
    """Ray-based distributed executor"""

    uses_ray: bool = True
    supports_pp: bool = True

    def _init_executor(self) -> None:
        """Bring up the Ray cluster and workers, then derive executor flags.

        Sets `forward_dag`, `has_connector`, `uses_sampler` and
        `scheduler_output` on the instance.
        """
        # The compiled DAG is built lazily on the first `_execute_dag()` call.
        self.forward_dag: ray.dag.CompiledDAG | None = None

        # For TPU or XPU, avoid compiling NVIDIA's NCCL
        if current_platform.is_tpu() or current_platform.is_xpu():
            os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"

        assert self.uses_ray
        initialize_ray_cluster(self.parallel_config)
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection unless it was explicitly
        # enabled with RAY_USAGE_STATS_ENABLED=1.
        if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        # KV connector setup
        self.has_connector = self.vllm_config.kv_transfer_config is not None

        self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
            self.vllm_config.ec_transfer_config is None
            or self.vllm_config.ec_transfer_config.is_ec_consumer
        )

        # Scheduler output stashed by execute_model() until sample_tokens()
        # consumes it.
        self.scheduler_output: SchedulerOutput | None = None

    @property
    def max_concurrent_batches(self) -> int:
        """Ray distributed executor supports pipeline parallelism,
        meaning that it allows PP size batches to be executed concurrently.

        When the PP size is at most 1 and async scheduling is enabled,
        two concurrent batches are allowed instead.
        """
        pp_size = self.parallel_config.pipeline_parallel_size
        if pp_size <= 1 and self.scheduler_config.async_scheduling:
            return 2
        return pp_size

    def shutdown(self) -> None:
        """Tear down the compiled DAG and kill every Ray worker actor.

        When no compiled DAG exists (including when `forward_dag` was never
        set on this instance) only the log message is emitted, so a second
        call after a successful shutdown does no further work.
        """
        if logger:
            # Somehow logger can be None here.
            logger.info(
                "Shutting down Ray distributed executor. If you see error log "
                "from logging.cc regarding SIGTERM received, please ignore "
                "because this is the expected termination process in Ray."
            )
        if getattr(self, "forward_dag", None) is None:
            return

        self.forward_dag.teardown()
        # NOTE(review): presumably imported locally because module globals
        # may already be cleared when this runs from __del__ at interpreter
        # exit (cf. the logger check above) -- confirm.
        import ray

        for worker in self.workers:
            ray.kill(worker)
        self.forward_dag = None

    def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]:
123
124
125
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
126
127
128
129
130
131
132
        runtime_env.update(
            {
                "nsight": {
                    "t": "cuda,cudnn,cublas",
                    "o": "'worker_process_%p'",
                    "cuda-graph-trace": "node",
                }
133
            }
134
        )
135
136
137

        return ray_remote_kwargs

138
139
140
141
142
143
144
145
    def _update_noset_device_env_vars(self, ray_remote_kwargs):
        """Set each of the platform's "noset device" env vars to "1".

        The variables named by `current_platform.ray_noset_device_env_vars`
        are written into `ray_remote_kwargs["runtime_env"]["env_vars"]`;
        both intermediate dicts are created when missing. The kwargs dict is
        modified in place and also returned.
        """
        env_vars = ray_remote_kwargs.setdefault("runtime_env", {}).setdefault(
            "env_vars", {}
        )
        env_vars.update(dict.fromkeys(current_platform.ray_noset_device_env_vars, "1"))
        return ray_remote_kwargs

    # child class could overwrite this to return actual env vars.
    def _get_env_vars_to_be_updated(self):
        return self._env_vars_for_all_workers

    def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs):
        """Create the Ray worker actors and initialize them end to end.

        Steps, in order:
          1. Create one `RayWorkerWrapper` actor per selected placement-group
             bundle.
          2. Re-rank the workers so that those on the driver node come first
             and workers on the same node are adjacent.
          3. Collect node / GPU ids, then push the device-visibility variable
             and the copied driver environment variables to every worker.
          4. Initialize the workers and their devices, and load the model.
          5. Build `self.pp_tp_workers`, indexed by PP rank and then TP rank.

        Args:
            placement_group: Ray placement group whose bundles host the
                workers.
            **ray_remote_kwargs: Extra keyword arguments forwarded to
                `ray.remote()` when creating each actor.

        Raises:
            RuntimeError: If the number of distinct node ids differs from the
                number of distinct IP addresses (workers plus driver).
        """
        num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: RayWorkerWrapper | None = None
        # The remaining workers are the actual ray actors.
        self.workers: list[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: list[list[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs
            )

        # The way ray actors are setup in vllm is that the visible devices are
        # not set by actors, they are left unset by ray. Internally we index
        # the right gpu with local_rank. This is similar to how mp mode works.
        self._update_noset_device_env_vars(ray_remote_kwargs)

        # Create the workers.
        bundle_indices: list[int]
        if envs.VLLM_RAY_BUNDLE_INDICES:
            # Use the bundle indices specified by the user.
            bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
            assert len(bundle_indices) == self.parallel_config.world_size, (
                "VLLM_RAY_BUNDLE_INDICES must have the same size"
                f" as the world size, but got {bundle_indices=} "
                f"and {self.parallel_config.world_size=}"
            )
            assert len(set(bundle_indices)) == len(bundle_indices), (
                "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
                f" but got {bundle_indices=}"
            )
        else:
            # use the first N bundles that have GPU resources.
            bundle_indices = []
            for bundle_id, bundle in enumerate(placement_group.bundle_specs):
                if bundle.get(current_platform.ray_device_key, 0):
                    bundle_indices.append(bundle_id)
            bundle_indices = bundle_indices[: self.parallel_config.world_size]

        worker_metadata: list[RayWorkerMetaData] = []
        driver_ip = get_ip()
        for rank, bundle_id in enumerate(bundle_indices):
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            if current_platform.ray_device_key == "GPU":
                # NV+AMD GPUs, and Intel XPUs
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=num_gpus,
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(rpc_rank=rank)
            else:
                # Other devices are requested as a custom Ray resource keyed
                # by the platform's device key.
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=0,
                    resources={current_platform.ray_device_key: num_gpus},
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(rpc_rank=rank)

            worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))

        worker_ips = ray.get(
            [
                each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
                for each in worker_metadata
            ]
        )

        for each, ip in zip(worker_metadata, worker_ips):
            each.ip = ip

        logger.debug("workers: %s", worker_metadata)
        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)

        # Number of workers hosted on each IP address.
        ip_counts: dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the work is on a node with smaller IP address, it
                should be placed first.
            """
            ip = item.ip
            return 0 if ip == driver_ip else 1, ip_counts[ip], ip

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        sorted_worker_metadata = sorted(
            worker_metadata, key=sort_by_driver_then_worker_ip
        )
        for i, item in enumerate(sorted_worker_metadata):
            item.adjusted_rank = i
        self.workers = [item.worker for item in sorted_worker_metadata]
        rerank_mapping = {
            item.created_rank: item.adjusted_rank for item in sorted_worker_metadata
        }
        self.collective_rpc("adjust_rank", args=(rerank_mapping,))

        # Get the set of GPU IDs used on each node. All requests are issued
        # up front and resolved with a single ray.get(), so the workers answer
        # concurrently instead of one blocking round trip per worker.
        # driver_dummy_worker can be None when using ray spmd worker.
        worker_node_and_gpu_ids = ray.get(
            [
                worker.get_node_and_gpu_ids.remote()  # type: ignore[attr-defined]
                for worker in [self.driver_dummy_worker] + self.workers
                if worker is not None
            ]
        )

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP`"
                " environment variable, make sure it is unique for"
                " each node."
            )

        # Set environment variables for the driver and workers.
        # We set CUDA_VISIBLE_DEVICES to ALL GPUs on the node for each worker.
        # This is needed because:
        # 1. Ray's compiled DAG needs to find the allocated GPU in
        #    CUDA_VISIBLE_DEVICES.
        # 2. vLLM's communication layer (NCCL, CustomAllreduce) needs to see
        #    all GPUs for P2P checks and communication setup. Though if it was
        #    just this reason, we could have also just kept the visible devices
        #    unset.
        # Each worker will use local_rank to index into the visible devices.
        all_args_to_update_environment_variables = [
            {
                current_platform.device_control_env_var: ",".join(
                    map(str, node_gpus[node_id])
                ),
            }
            for (node_id, _) in worker_node_and_gpu_ids
        ]

        # Environment variables to copy from driver to workers
        env_vars_to_copy = get_env_vars_to_copy(
            exclude_vars=WORKER_SPECIFIC_ENV_VARS,
            additional_vars=set(current_platform.additional_env_vars),
            destination="workers",
        )

        # Copy existing env vars to each worker's args
        for args in all_args_to_update_environment_variables:
            # TODO: refactor platform-specific env vars
            for name in env_vars_to_copy:
                if name in os.environ:
                    args[name] = os.environ[name]

        self._env_vars_for_all_workers = all_args_to_update_environment_variables

        self.collective_rpc(
            "update_environment_variables", args=(self._get_env_vars_to_be_updated(),)
        )

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port()
        )

        # Initialize the actual workers inside worker wrapper.
        all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            # Position of this rank among the workers on the same node.
            local_rank = node_workers[node_id].index(rank)
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
            all_kwargs.append(kwargs)
        self.collective_rpc("init_worker", args=(all_kwargs,))

        self.collective_rpc("init_device")
        if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
            self.collective_rpc("elastic_ep_execute", args=("load_model",))
        else:
            self.collective_rpc("load_model")

        def _update_block_size(worker):
            current_platform.update_block_size_for_backend(worker.vllm_config)

        self.collective_rpc(_update_block_size)

        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
            self.pp_tp_workers.append([])
            for tp_rank in range(self.parallel_config.tensor_parallel_size):
                # PP=2, TP=4
                # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                rank = (pp_rank * self.parallel_config.tensor_parallel_size) + tp_rank
                assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                assert pp_rank < len(self.pp_tp_workers)
                self.pp_tp_workers[pp_rank].append(self.workers[rank])

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Forward a distributed-reconfiguration request to every worker.

        The call to the workers is blocking. Afterwards, if the request marks
        the current rank for shutdown, this executor is shut down as well.
        """
        self.collective_rpc("reinitialize_distributed", args=(reconfig_request,))
        new_dp_rank = reconfig_request.new_data_parallel_rank
        if new_dp_rank == ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
            self.shutdown()

    def execute_model(  # type: ignore[override]
        self,
        scheduler_output: SchedulerOutput,
        non_block: bool = False,
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Run the DAG now, or stash the scheduler output for sampling.

        When the executor does not use a sampler, or no tokens are scheduled,
        the DAG is executed right away with no grammar output. Otherwise the
        scheduler output is stored and execution is deferred to the next
        `sample_tokens()` call.

        Args:
            scheduler_output: The scheduler output for this step.
            non_block: If True, a Future is returned instead of a plain value.

        Returns:
            The result of `_execute_dag()` on the immediate path; otherwise
            None, or an already-completed Future holding None when
            `non_block` is True.

        Raises:
            RuntimeError: If a previously stashed scheduler output has not
                been consumed by `sample_tokens()` yet.
        """
        if self.scheduler_output is not None:
            raise RuntimeError(
                "State error: sample_tokens() must be called "
                "after execute_model() returns None."
            )

        if not self.uses_sampler or not scheduler_output.total_num_scheduled_tokens:
            # Model will not execute, call model runner immediately.
            return self._execute_dag(scheduler_output, None, non_block)

        # Model will execute, defer to sample_tokens() call.
        self.scheduler_output = scheduler_output
        return COMPLETED_NONE_FUTURE if non_block else None

    def sample_tokens(  # type: ignore[override]
        self,
        grammar_output: "GrammarOutput | None",
        non_block: bool = False,
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Execute the model on the Ray workers.

        The scheduler output to use should have been provided in
        a prior call to execute_model().

        Args:
            grammar_output: The structured outputs grammar bitmask, if applicable.
            non_block: If True, the method will return a Future.

        Returns:
            The model runner output. If no scheduler output is pending, None
            (or an already-completed Future holding None when `non_block`
            is True).
        """
        scheduler_output = self.scheduler_output
        if scheduler_output is None:
            return COMPLETED_NONE_FUTURE if non_block else None

        # Clear the stash before executing so that execute_model() can
        # accept the next step.
        self.scheduler_output = None

        return self._execute_dag(scheduler_output, grammar_output, non_block)

    def _execute_dag(
        self,
        scheduler_output: SchedulerOutput,
        grammar_output: "GrammarOutput | None",
        non_block: bool = False,
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Execute one step on the compiled Ray DAG, building it on first use.

        Args:
            scheduler_output: The scheduler output fed to the DAG.
            grammar_output: The grammar output fed to the DAG, or None.
            non_block: If True, return a `FutureWrapper` instead of blocking.

        Returns:
            Without a KV connector, the output of the first DAG output ref;
            with a KV connector, the outputs of all refs combined by
            `self.kv_output_aggregator`. In both cases the value is wrapped
            in a `FutureWrapper` when `non_block` is True.
        """
        # Build the compiled DAG for the first time.
        if self.forward_dag is None:  # type: ignore
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        refs = self.forward_dag.execute((scheduler_output, grammar_output))  # type: ignore

        if not self.has_connector:
            # Get output only from a single worker (output_rank)
            # When PP is not used, we block here until the result is available.
            if not non_block:
                return refs[0].get()

            # When PP is used, we return a FutureWrapper immediately so that
            # the scheduler can yield to the next batch.
            return FutureWrapper(refs[0])

        # Get output from all workers when connector is present
        assert self.kv_output_aggregator is not None
        if not non_block:
            # Block and get results from all workers
            return self.kv_output_aggregator.aggregate(ray.get(refs))

        # Return a future that will aggregate outputs from all workers
        return FutureWrapper(refs, self.kv_output_aggregator)

    def collective_rpc(  # type: ignore[override]
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
        non_block: bool = False,
    ) -> list[Any] | Future[list[Any]]:
        """Runs the given method on all workers.

        Args:
            method: Name of the worker method, or a callable; a callable is
                serialized with cloudpickle before being sent.
            timeout: Passed to `ray.get()` in blocking mode. It is not used
                when `non_block` is True.
            args: Positional arguments forwarded to every worker.
            kwargs: Keyword arguments forwarded to every worker.
            non_block: If True, return a `FutureWrapper` over the pending
                results instead of waiting for them.

        Returns:
            The list of per-worker results, in `self.workers` order, or a
            `FutureWrapper` around the pending results.
        """
        sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
        del method

        if kwargs is None:
            kwargs = {}
        ray_worker_outputs = [
            worker.execute_method.remote(  # type: ignore[attr-defined]
                sent_method, *args, **kwargs
            )
            for worker in self.workers
        ]

        # Get the results of the ray workers.
        if non_block:
            return FutureWrapper(ray_worker_outputs)

        return ray.get(ray_worker_outputs, timeout=timeout)

    def _check_ray_cgraph_installation(self):
        """Verify that the prerequisites of Ray Compiled Graph are installed.

        Raises:
            ValueError: If the installed Ray version is older than 2.43.0,
                if the `ray.experimental.compiled_dag_ref` module cannot be
                found, or if `cupy` is missing while
                VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is "nccl".
        """
        import importlib.metadata

        from packaging import version

        required_version = version.parse("2.43.0")
        current_version = version.parse(importlib.metadata.version("ray"))
        if current_version < required_version:
            raise ValueError(
                f"Ray version {required_version} is "
                f"required, but found {current_version}"
            )

        import importlib.util

        cgraph_spec = importlib.util.find_spec("ray.experimental.compiled_dag_ref")
        if cgraph_spec is None:
            raise ValueError(
                "Ray Compiled Graph is not installed. "
                "Run `pip install ray[cgraph]` to install it."
            )

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl":
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
                "Run `pip install ray[cgraph]` and check cupy installation."
            )

    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the Ray DAG that runs one model step.

        Every TP worker of the first PP stage receives the DAG input; each
        later stage consumes the output of the same TP index in the previous
        stage. The outputs of the last stage are the outputs of the DAG.

        Args:
            enable_asyncio: Forwarded to `experimental_compile()`.

        Returns:
            The compiled DAG returned by `experimental_compile()`.

        Raises:
            ValueError: If VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is not one
                of 'auto', 'nccl' or 'shm', or if the installation check
                fails.
        """
        assert self.parallel_config.use_ray
        self._check_ray_cgraph_installation()
        # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
        # (it is 10 seconds by default). This is a Ray environment variable to
        # control the timeout of getting result from a compiled graph execution,
        # i.e., the distributed execution that includes model forward runs and
        # intermediate tensor communications, in the case of vllm.
        # Note: we should set this env var before importing
        # ray.dag, otherwise it will not take effect.
        os.environ.setdefault("RAY_CGRAPH_get_timeout", "300")  # noqa: SIM112
        from ray.dag import InputNode, MultiOutputNode

        logger.info(
            "RAY_CGRAPH_get_timeout is set to %s",
            os.environ["RAY_CGRAPH_get_timeout"],  # noqa: SIM112
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE,
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )

        channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
        if channel_type not in ("auto", "nccl", "shm"):
            raise ValueError(
                "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
                f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'."
            )

        # Loop-invariant: index of the final PP stage.
        last_pp_rank = len(self.pp_tp_workers) - 1

        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            #
            # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput   # noqa: E501

            # All workers in the first TP group will take in the
            # ExecuteModelRequest as input.
            outputs = [input_data for _ in self.pp_tp_workers[0]]
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                outputs = [
                    worker.execute_model_ray.bind(outputs[i])  # type: ignore[attr-defined]
                    for i, worker in enumerate(tp_group)
                ]

                if pp_rank < last_pp_rank and channel_type != "shm":
                    # Specify how intermediate tensors should be passed
                    # between pp stages, no need to specify for the last
                    # pp stage or when using shared memory (the default).
                    outputs = [
                        output.with_tensor_transport(transport=channel_type)
                        for output in outputs
                    ]

            forward_dag = MultiOutputNode(outputs)

        if envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
            from ray.experimental.channel.accelerator_context import (
                register_accelerator_context,
            )

            from vllm.distributed.device_communicators.ray_communicator import (
                RayPPCommunicator,
            )

            register_accelerator_context(
                torch_module_name="cuda", communicator_cls=RayPPCommunicator
            )
            logger.info(
                "Using RayPPCommunicator "
                "(which wraps vLLM _PP GroupCoordinator) "
                "for Ray Compiled Graph communication."
            )
        else:
            logger.info(
                "Using Ray's NCCL communicator for Ray Compiled Graph communication."
            )

        return forward_dag.experimental_compile(
            enable_asyncio=enable_asyncio,
            _overlap_gpu_communication=envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )

    def __del__(self):
        """Shut the executor down when the object is finalized."""
        self.shutdown()

    def check_health(self) -> None:
        """Health-check placeholder: performs no check and returns None.

        TODO: check the health of the Ray workers
        """
        # Assume that the Ray workers are healthy.
        return None