ray_executor.py 27.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
6
from collections import defaultdict, deque
from collections.abc import Callable, Sequence
7
from concurrent.futures import Future
8
from dataclasses import dataclass
9
from typing import TYPE_CHECKING, Any
10
from functools import partial
11

12
import cloudpickle
13

14
import vllm.envs as envs
15
from vllm.logger import init_logger
16
from vllm.platforms import current_platform
17
from vllm.ray.ray_env import get_env_vars_to_copy
18
from vllm.utils.network_utils import (
19
20
21
22
    get_distributed_init_method,
    get_ip,
    get_open_port,
)
23
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
24
25
26
27
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (
    FutureWrapper,
28
    NonBlockFutureWrapper,
29
30
31
32
33
    RayWorkerWrapper,
    initialize_ray_cluster,
    ray,
)
from vllm.v1.outputs import ModelRunnerOutput
34
35

if ray is not None:
36
    from ray.actor import ActorHandle
37
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
38
39
    from ray.util.queue import Queue as RayQueue
    from ray.util.queue import Empty as EmptyError
40
41
else:
    ActorHandle = None
42
43
44
45
46
47

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)

# A shared future that is already resolved with ``None``. It is returned by
# ``execute_model`` / ``sample_tokens`` when ``non_block=True`` but there is
# no model output to wait for.
COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future()
COMPLETED_NONE_FUTURE.set_result(None)

51

52
53
54
55
56
57
58
@dataclass
class RayWorkerMetaData:
    """Bookkeeping record for a single Ray worker actor.

    Ray may bring the worker actors up in any order, so every record keeps
    the rank the worker was created with alongside the rank it is given
    once all workers exist and have been sorted.
    """

    # Handle of the Ray actor that wraps the worker.
    worker: ActorHandle
    # Position of the worker in the creation loop.
    created_rank: int
    # Rank assigned after sorting; -1 until it has been computed.
    adjusted_rank: int = -1
    # IP address of the node hosting the actor; empty until queried.
    ip: str = ""


66
class RayDistributedExecutor(Executor):
67
68
69
70
71
    """Ray-based distributed executor"""

    # These env vars are worker-specific, therefore are NOT copied
    # from the driver to the workers
    WORKER_SPECIFIC_ENV_VARS = {
72
73
74
75
        "VLLM_HOST_IP",
        "VLLM_HOST_PORT",
        "LOCAL_RANK",
        "CUDA_VISIBLE_DEVICES",
76
77
    }

78
79
80
    # These non-vLLM env vars are copied from the driver to workers
    ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}

81
    uses_ray: bool = True
82
    supports_pp: bool = True
83

84
    def _init_executor(self) -> None:
        """Bring up the Ray cluster and the worker actors for this executor.

        Marks the compiled DAG as "not built yet", records whether a KV
        connector and a sampler are in use, and spawns/initializes the Ray
        workers inside the configured placement group.
        """
        self.forward_dag: ray.dag.CompiledDAG | None = None

        # TPU and XPU cannot use NVIDIA's NCCL, so force the compiled-DAG
        # channels onto shared memory for those platforms.
        if current_platform.is_tpu() or current_platform.is_xpu():
            os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"

        # Must be known before the workers are created: _init_workers_ray()
        # reads it when choosing the output-rank argument for "adjust_rank".
        kv_transfer_config = self.vllm_config.kv_transfer_config
        self.has_connector = kv_transfer_config is not None

        assert self.uses_ray
        initialize_ray_cluster(self.parallel_config)
        pg = self.parallel_config.placement_group

        # Opt out of Ray usage-stats collection unless the user explicitly
        # enabled it with RAY_USAGE_STATS_ENABLED=1.
        if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(pg)

        # No sampling for pooling runners or for EC-transfer producers.
        if self.vllm_config.model_config.runner_type == "pooling":
            self.uses_sampler = False
        else:
            ec_config = self.vllm_config.ec_transfer_config
            self.uses_sampler = ec_config is None or not ec_config.is_ec_producer

        # Holds the output passed to execute_model() until sample_tokens().
        self.scheduler_output: SchedulerOutput | None = None

113
114
115
116
117
    @property
    def max_concurrent_batches(self) -> int:
        """How many batches may be in flight at once.

        With pipeline parallelism this is the pipeline-parallel size. Without
        PP (size <= 1) but with async scheduling enabled it is 2; otherwise
        the pipeline-parallel size is returned unchanged.
        """
        pp_size = self.parallel_config.pipeline_parallel_size
        if pp_size > 1 or not self.scheduler_config.async_scheduling:
            return pp_size
        return 2
120

121
    def shutdown(self) -> None:
        """Tear down the compiled DAG, kill the workers and free response queues.

        May be called more than once (explicitly, e.g. from
        ``reinitialize_distributed``, and again from ``__del__``) and on a
        partially-initialized executor: every attribute is looked up
        defensively.
        """
        if logger:
            # Somehow logger can be None here (module globals may already
            # have been cleared, e.g. during interpreter shutdown).
            logger.info(
                "Shutting down Ray distributed executor. If you see error log "
                "from logging.cc regarding SIGTERM received, please ignore "
                "because this is the expected termination process in Ray."
            )

        if getattr(self, "forward_dag", None) is not None:
            self.forward_dag.teardown()
            import ray

            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

        # In async-scheduling mode every worker was given a RayQueue for its
        # outputs. ray.util.queue.Queue is backed by a Ray actor, so release
        # them explicitly; otherwise repeated shutdown/reinitialization leaks
        # one queue actor per worker.
        response_mqs = getattr(self, "response_mqs", None) or []
        for mq in response_mqs:
            if mq is None:
                continue
            try:
                mq.shutdown(force=True)
            except Exception:
                # Best effort: Ray itself may already be gone (e.g. when
                # called from __del__ at interpreter exit).
                if logger:
                    logger.debug(
                        "Failed to shut down a Ray response queue.", exc_info=True
                    )
        if response_mqs:
            # Drop the references so a second shutdown() is a no-op here.
            self.response_mqs = [None] * len(response_mqs)

137
    def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]:
        """Add an nsight profiling entry to the workers' Ray runtime env.

        ``ray_remote_kwargs`` is modified in place (a ``runtime_env`` dict is
        created when absent) and is also returned for convenience.
        """
        nsight_options = {
            "t": "cuda,cudnn,cublas",
            "o": "'worker_process_%p'",
            "cuda-graph-trace": "node",
        }
        env = ray_remote_kwargs.setdefault("runtime_env", {})
        env.update({"nsight": nsight_options})
        return ray_remote_kwargs

153
154
155
156
    # Hook: subclasses may override this to supply their own env vars.
    def _get_env_vars_to_be_updated(self):
        """Return the per-worker environment-variable dicts to push.

        By default this is the list assembled in ``_init_workers_ray``
        (one dict per worker).
        """
        env_vars = self._env_vars_for_all_workers
        return env_vars

157
    def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs):
        """Create, re-rank and initialize one Ray worker actor per bundle.

        Steps:
        1. Choose the placement-group bundles and spawn one
           ``RayWorkerWrapper`` actor per bundle (plus one response queue per
           worker when async scheduling is enabled).
        2. Sort the workers so that workers on the driver node come first and
           workers on the same node are adjacent, then push the new ranks to
           the workers via ``adjust_rank``.
        3. Compute and push per-worker environment variables (device
           visibility plus variables copied from the driver).
        4. Run ``init_worker`` / ``init_device`` / ``load_model`` on all
           workers and build the PP x TP worker grid used by the compiled DAG.

        Args:
            placement_group: Ray placement group whose bundles host the
                workers.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote``
                for every worker actor.

        Raises:
            RuntimeError: If the number of distinct nodes and the number of
                distinct IP addresses (workers plus driver) differ.
        """
        num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: RayWorkerWrapper | None = None
        # The remaining workers are the actual ray actors.
        self.workers: list[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: list[list[RayWorkerWrapper]] = []

        self.output_rank = self._get_output_rank()

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs
            )

        # Create the workers.
        bundle_indices: list[int]
        if envs.VLLM_RAY_BUNDLE_INDICES:
            # Use the bundle indices specified by the user.
            bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
            assert len(bundle_indices) == self.parallel_config.world_size, (
                "VLLM_RAY_BUNDLE_INDICES must have the same size"
                f" as the world size, but got {bundle_indices=} "
                f"and {self.parallel_config.world_size=}"
            )
            assert len(set(bundle_indices)) == len(bundle_indices), (
                "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
                f" but got {bundle_indices=}"
            )
        else:
            # use the first N bundles that have GPU resources.
            bundle_indices = []
            for bundle_id, bundle in enumerate(placement_group.bundle_specs):
                if bundle.get(current_platform.ray_device_key, 0):
                    bundle_indices.append(bundle_id)
            bundle_indices = bundle_indices[: self.parallel_config.world_size]

        worker_metadata: list[RayWorkerMetaData] = []
        driver_ip = get_ip()

        # Per-worker response queues (async scheduling only). They are
        # collected in creation-rank order here and re-indexed by adjusted
        # rank once the workers have been sorted below.
        self.response_mqs = [None] * len(bundle_indices)
        response_mqs_tmp = [None] * len(bundle_indices)
        use_async_scheduling = self.scheduler_config.async_scheduling
        for rank, bundle_id in enumerate(bundle_indices):
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            # Workers report outputs through this queue in async scheduling mode.
            response_mq = None
            if use_async_scheduling:
                response_mq = RayQueue(maxsize=256)
                response_mqs_tmp[rank] = response_mq

            if current_platform.ray_device_key == "GPU":
                # NV+AMD GPUs, and Intel XPUs
                resource_kwargs: dict[str, Any] = {"num_gpus": num_gpus}
            else:
                # Other accelerators are requested as a custom Ray resource.
                resource_kwargs = {
                    "num_gpus": 0,
                    "resources": {current_platform.ray_device_key: num_gpus},
                }
            worker = ray.remote(
                num_cpus=0,
                scheduling_strategy=scheduling_strategy,
                **resource_kwargs,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                use_async_scheduling=use_async_scheduling,
                response_mq=response_mq,
                rpc_rank=rank,
            )

            worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))

        worker_ips = ray.get(
            [
                each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
                for each in worker_metadata
            ]
        )

        for each, ip in zip(worker_metadata, worker_ips):
            each.ip = ip

        logger.debug("workers: %s", worker_metadata)
        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)

        ip_counts: dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, workers are ordered by their IP address string
                (lexicographic comparison).
            """
            ip = item.ip
            return 0 if ip == driver_ip else 1, ip_counts[ip], ip

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        sorted_worker_metadata = sorted(
            worker_metadata, key=sort_by_driver_then_worker_ip
        )
        for i, item in enumerate(sorted_worker_metadata):
            item.adjusted_rank = i
        self.workers = [item.worker for item in sorted_worker_metadata]
        rerank_mapping = {
            item.created_rank: item.adjusted_rank for item in sorted_worker_metadata
        }
        # With a KV connector, outputs are collected from every worker, so
        # -1 is passed instead of a single output rank (presumably meaning
        # "no designated output rank" on the worker side — confirm).
        self.collective_rpc(
            "adjust_rank",
            args=(rerank_mapping, -1 if self.has_connector else self.output_rank),
        )

        # Re-index the queues so self.response_mqs[i] belongs to self.workers[i].
        for created_rank, adjusted_rank in rerank_mapping.items():
            self.response_mqs[adjusted_rank] = response_mqs_tmp[created_rank]

        # Get the set of GPU IDs used on each node. All requests are issued
        # up front and awaited together rather than one round trip per worker.
        queried_workers = [
            worker
            for worker in [self.driver_dummy_worker] + self.workers
            # driver_dummy_worker can be None when using ray spmd worker.
            if worker is not None
        ]
        worker_node_and_gpu_ids = ray.get(
            [
                worker.get_node_and_gpu_ids.remote()  # type: ignore[attr-defined]
                for worker in queried_workers
            ]
        )

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP`"
                " environment variable, make sure it is unique for"
                " each node."
            )

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [
            {
                current_platform.device_control_env_var: ",".join(
                    map(str, node_gpus[node_id])
                ),
            }
            for (node_id, _) in worker_node_and_gpu_ids
        ]

        # Environment variables to copy from driver to workers
        env_vars_to_copy = get_env_vars_to_copy(
            exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
            additional_vars=set(current_platform.additional_env_vars).union(
                self.ADDITIONAL_ENV_VARS
            ),
            destination="workers",
        )

        # Copy existing env vars to each worker's args
        for args in all_args_to_update_environment_variables:
            # TODO: refactor platform-specific env vars
            for name in env_vars_to_copy:
                if name in os.environ:
                    args[name] = os.environ[name]

        self._env_vars_for_all_workers = all_args_to_update_environment_variables

        self.collective_rpc(
            "update_environment_variables", args=(self._get_env_vars_to_be_updated(),)
        )

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port()
        )

        # Initialize the actual workers inside worker wrapper.
        all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            local_rank = node_workers[node_id].index(rank)
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
            all_kwargs.append(kwargs)
        self.collective_rpc("init_worker", args=(all_kwargs,))

        self.collective_rpc("init_device")
        self.collective_rpc("load_model")

        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
            self.pp_tp_workers.append([])
            for tp_rank in range(self.parallel_config.tensor_parallel_size):
                # PP=2, TP=4
                # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                rank = (pp_rank * self.parallel_config.tensor_parallel_size) + tp_rank
                assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                assert pp_rank < len(self.pp_tp_workers)
                self.pp_tp_workers[pp_rank].append(self.workers[rank])

        if self.scheduler_config.async_scheduling:
            # Pending (future, fetch-callback) pairs consumed by
            # NonBlockFutureWrapper in async scheduling mode.
            self.futures_queue = deque[tuple[NonBlockFutureWrapper, Callable]]()


403
404
405
406
407
408
409
410
411
412
413
    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Forward a distributed reconfiguration request to every worker.

        If the request asks the current data-parallel rank to shut down, this
        executor is shut down after the workers have handled it.
        """
        self.collective_rpc("reinitialize_distributed", args=(reconfig_request,))
        shutting_down = (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        )
        if shutting_down:
            self.shutdown()

    def execute_model(  # type: ignore[override]
        self,
        scheduler_output: SchedulerOutput,
        non_block: bool = False,
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """First half of a step: run now, or stash the output for sampling.

        If no sampling will happen (the executor does not use a sampler, or
        no tokens were scheduled), the DAG is executed immediately and its
        result returned. Otherwise ``scheduler_output`` is stored and
        ``None`` is returned (a completed ``None`` future when
        ``non_block``); the DAG then runs on the next ``sample_tokens()``.

        Raises:
            RuntimeError: If a previously stored scheduler output has not yet
                been consumed by ``sample_tokens()``.
        """
        if self.scheduler_output is not None:
            raise RuntimeError(
                "State error: sample_tokens() must be called "
                "after execute_model() returns None."
            )

        will_sample = (
            self.uses_sampler and scheduler_output.total_num_scheduled_tokens
        )
        if not will_sample:
            # Nothing to sample: dispatch to the workers right away.
            return self._execute_dag(scheduler_output, None, non_block)

        # Defer the actual execution to the upcoming sample_tokens() call.
        self.scheduler_output = scheduler_output
        if non_block:
            return COMPLETED_NONE_FUTURE
        return None

    def sample_tokens(  # type: ignore[override]
        self,
        grammar_output: "GrammarOutput | None",
        non_block: bool = False,
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Execute the model on the Ray workers for the stashed scheduler output.

        The scheduler output must have been handed over by an earlier
        ``execute_model()`` call.

        Args:
            grammar_output: The structured outputs grammar bitmask, if applicable.
            non_block: If True, the method will return a Future.

        Returns:
            The model runner output; ``None`` (or a completed ``None`` future
            when ``non_block``) if no scheduler output is pending.
        """
        pending = self.scheduler_output
        if pending is None:
            if non_block:
                return COMPLETED_NONE_FUTURE
            return None

        # Consume the stashed output before dispatching, so it is cleared
        # even if execution raises.
        self.scheduler_output = None
        return self._execute_dag(pending, grammar_output, non_block)

    def _execute_dag(
        self,
        scheduler_output: SchedulerOutput,
        grammar_output: "GrammarOutput | None",
        non_block: bool = False,
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Run one step through the compiled Ray DAG and collect its output.

        The compiled DAG is built lazily on the first call.

        Without async scheduling the output is read from the DAG refs: from
        the output rank only, or from every worker (combined by
        ``kv_output_aggregator``) when a KV connector is configured.

        With async scheduling the output is read from the per-worker response
        queues instead, and a ``NonBlockFutureWrapper`` is returned whatever
        the value of ``non_block``.

        Raises (async scheduling, when the future's result is fetched):
            TimeoutError: If a worker does not respond within
                ``VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS``.
            RuntimeError: If a worker reports a non-success status.
        """
        # Build the compiled DAG for the first time.
        if self.forward_dag is None:  # type: ignore
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        refs = self.forward_dag.execute((scheduler_output, grammar_output))  # type: ignore

        if not self.scheduler_config.async_scheduling:
            if not self.has_connector:
                # Get output only from a single worker (output_rank)
                # When PP is not used, we block here until the result is available.
                if not non_block:
                    return refs[0].get()

                # When PP is used, we return a FutureWrapper immediately so that
                # the scheduler can yield to the next batch.
                return FutureWrapper(refs[0])

            # Get output from all workers when connector is present
            assert self.kv_output_aggregator is not None
            if not non_block:
                # Block and get results from all workers
                return self.kv_output_aggregator.aggregate(ray.get(refs))

            # Return a future that will aggregate outputs from all workers
            return FutureWrapper(refs, self.kv_output_aggregator)

        # Async scheduling: workers push (status, result) pairs into their
        # response queues, so the DAG refs are not read here.
        # NOTE(review): `refs` is dropped unread in this mode — confirm Ray
        # releases the corresponding result buffers.
        single_output = not self.has_connector
        response_mqs: Sequence[RayQueue]
        aggregate: Callable[[Any], Any]
        if single_output:
            # Only the output rank's result is needed; pass it through as-is.
            response_mqs = (self.response_mqs[self.output_rank],)
            aggregate = lambda result: result  # noqa: E731
        else:
            # Same precondition as the non-async connector path above.
            assert self.kv_output_aggregator is not None
            response_mqs = self.response_mqs
            aggregate = partial(
                self.kv_output_aggregator.aggregate, output_rank=self.output_rank
            )

        def get_response():
            timeout = envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS
            responses = []
            for mq in response_mqs:
                try:
                    status, result = mq.get(timeout=timeout)
                except EmptyError as e:
                    raise TimeoutError(
                        f"ray exec timed out after {timeout} seconds."
                    ) from e
                if status != RayWorkerWrapper.ResponseStatus.SUCCESS:
                    raise RuntimeError(
                        f"Worker failed with error '{result}', please check the"
                        " stack trace above for the root cause"
                    )
                responses.append(result)
            return responses[0] if single_output else responses

        future = NonBlockFutureWrapper(self.futures_queue, aggregate=aggregate)
        self.futures_queue.appendleft((future, get_response))

        return future
520

521
    def collective_rpc(  # type: ignore[override]
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
        non_block: bool = False,
    ) -> list[Any] | Future[list[Any]]:
        """Runs the given method on all workers.

        ``method`` is either the name of a worker method or a callable; a
        callable is serialized with cloudpickle before being sent. Calls are
        issued to the workers in ``self.workers`` order.

        Returns:
            The per-worker results in worker order or, when ``non_block`` is
            set, a ``FutureWrapper`` over them. ``timeout`` only applies to
            the blocking path.
        """
        if isinstance(method, str):
            payload = method
        else:
            payload = cloudpickle.dumps(method)
        del method

        call_kwargs = {} if kwargs is None else kwargs
        pending_refs = []
        for worker in self.workers:
            pending_refs.append(
                worker.execute_method.remote(  # type: ignore[attr-defined]
                    payload, *args, **call_kwargs
                )
            )

        # Get the results of the ray workers.
        if non_block:
            return FutureWrapper(pending_refs)
        return ray.get(pending_refs, timeout=timeout)
547

548
    def _check_ray_cgraph_installation(self):
        """Verify that the installed Ray supports Compiled Graphs.

        Raises:
            ValueError: If Ray is older than 2.43.0, if the Compiled Graph
                module is not importable, or if cupy is missing while
                ``VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE`` is ``"nccl"``.
        """
        import importlib.metadata
        import importlib.util

        from packaging import version

        min_ray = version.parse("2.43.0")
        installed_ray = version.parse(importlib.metadata.version("ray"))
        if installed_ray < min_ray:
            raise ValueError(
                f"Ray version {min_ray} is required, but found {installed_ray}"
            )

        if importlib.util.find_spec("ray.experimental.compiled_dag_ref") is None:
            raise ValueError(
                "Ray Compiled Graph is not installed. "
                "Run `pip install ray[cgraph]` to install it."
            )

        cupy_missing = importlib.util.find_spec("cupy") is None
        if cupy_missing and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl":
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
                "Run `pip install ray[cgraph]` and check cupy installation."
            )
577
578

    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the Ray Compiled Graph that runs one model step.

        Every TP worker of the first PP stage receives the DAG input; TP
        worker ``i`` of each later stage consumes the output of TP worker
        ``i`` of the previous stage, and the outputs of the last stage are
        the DAG outputs.

        Args:
            enable_asyncio: Forwarded to ``experimental_compile``.

        Returns:
            The compiled DAG.

        Raises:
            ValueError: If the Ray Compiled Graph prerequisites are missing or
                ``VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE`` is invalid.
        """
        assert self.parallel_config.use_ray
        self._check_ray_cgraph_installation()
        # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
        # (it is 10 seconds by default). This is a Ray environment variable to
        # control the timeout of getting result from a compiled graph execution,
        # i.e., the distributed execution that includes model forward runs and
        # intermediate tensor communications, in the case of vllm.
        # Note: we should set this env var before importing
        # ray.dag, otherwise it will not take effect.
        os.environ.setdefault("RAY_CGRAPH_get_timeout", "300")  # noqa: SIM112
        from ray.dag import InputNode, MultiOutputNode

        logger.info(
            "RAY_CGRAPH_get_timeout is set to %s",
            os.environ["RAY_CGRAPH_get_timeout"],  # noqa: SIM112
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE,
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )

        channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
        if channel_type not in ("auto", "nccl", "shm"):
            raise ValueError(
                "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
                f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'."
            )

        last_pp_rank = len(self.pp_tp_workers) - 1
        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            #
            # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput   # noqa: E501

            # All workers in the first TP group take the DAG input (the
            # (scheduler_output, grammar_output) tuple passed to execute()).
            outputs = [input_data for _ in self.pp_tp_workers[0]]
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                outputs = [
                    worker.execute_model_ray.bind(outputs[i])  # type: ignore[attr-defined]
                    for i, worker in enumerate(tp_group)
                ]

                if pp_rank < last_pp_rank and channel_type != "shm":
                    # Specify how intermediate tensors should be passed
                    # between pp stages; no need to specify for the last
                    # pp stage or when using shared memory.
                    outputs = [
                        output.with_tensor_transport(transport=channel_type)
                        for output in outputs
                    ]

            forward_dag = MultiOutputNode(outputs)

        if envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
            from ray.experimental.channel.accelerator_context import (
                register_accelerator_context,
            )

            from vllm.distributed.device_communicators.ray_communicator import (
                RayPPCommunicator,
            )

            register_accelerator_context(
                torch_module_name="cuda", communicator_cls=RayPPCommunicator
            )
            logger.info(
                "Using RayPPCommunicator "
                "(which wraps vLLM _PP GroupCoordinator) "
                "for Ray Compiled Graph communication."
            )
        else:
            logger.info(
                "Using Ray's NCCL communicator for Ray Compiled Graph communication."
            )

        return forward_dag.experimental_compile(
            enable_asyncio=enable_asyncio,
            _overlap_gpu_communication=envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )
672
673

    def __del__(self):
        """Release Ray resources when the executor is garbage-collected.

        Delegates to ``shutdown()``, which guards its ``forward_dag`` access
        so it also copes with a partially-initialized executor.
        """
        self.shutdown()
675

676
677
678
679
    def check_health(self) -> None:
        """Health check placeholder: the Ray workers are assumed healthy.

        Nothing is probed and no exception is ever raised.
        """
        # TODO: check the health of the Ray workers
        return None
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
    
    def _get_output_rank(self) -> int:
        """Return the global rank whose ModelRunnerOutput is returned.

        This is the first worker of the last pipeline-parallel stage. The
        computation assumes each PP stage spans
        ``tensor_parallel_size * prefill_context_parallel_size`` consecutive
        ranks, so the last stage starts that many ranks before ``world_size``.

        Example (TP=8, PCP=1, PP=4, world_size=32):
            ranks 0-7   -> PP rank 0
            ranks 8-15  -> PP rank 1
            ranks 16-23 -> PP rank 2
            ranks 24-31 -> PP rank 3
        so the output rank is 32 - 8 * 1 = 24 (first rank of PP rank 3).
        """
        cfg = self.parallel_config
        world_size = cfg.world_size
        ranks_per_stage = cfg.tensor_parallel_size * cfg.prefill_context_parallel_size
        return world_size - ranks_per_stage