ray_distributed_executor.py 30.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import os
6
from collections import defaultdict
7
from collections.abc import Callable
8
from dataclasses import dataclass
9
from typing import TYPE_CHECKING, Any
10

11
import cloudpickle
12
13
import msgspec

14
import vllm.envs as envs
15
from vllm.executor.executor_base import DistributedExecutorBase
16
from vllm.executor.msgspec_utils import encode_hook
17
from vllm.executor.ray_utils import RayWorkerWrapper, initialize_ray_cluster, ray
18
from vllm.logger import init_logger
19
from vllm.platforms import current_platform
20
from vllm.ray.ray_env import get_env_vars_to_copy
21
from vllm.sequence import ExecuteModelRequest
22
23
from vllm.utils.asyncio import make_async
from vllm.utils.network_utils import (
24
25
26
27
    get_distributed_init_method,
    get_ip,
    get_open_port,
)
28
from vllm.v1.outputs import SamplerOutput
29
30

if ray is not None:
    from ray.actor import ActorHandle
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
    # Ray is not installed: keep ``ActorHandle`` defined so module-level
    # references to it (e.g. type annotations) do not raise NameError.
    # NOTE: PlacementGroupSchedulingStrategy is only bound when Ray is
    # available; code that uses it must only run with Ray installed.
    ActorHandle = None

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


42
43
44
45
46
47
48
@dataclass
class RayWorkerMetaData:
    """
    Metadata for a Ray worker.
    The order of ray worker creation can be random,
    and we need to reset the rank after creating all workers.
    """

    # Handle to the remote worker actor. The annotation is a string so the
    # dataclass does not depend on ``ActorHandle`` being importable at
    # runtime (it is None when Ray is not installed).
    worker: "ActorHandle"
    # Rank the worker was created with (its index in creation order).
    created_rank: int
    # Rank assigned after all workers are created and sorted; -1 means it
    # has not been assigned yet.
    adjusted_rank: int = -1
    # IP address of the node hosting the worker; "" until it is queried.
    ip: str = ""


class RayDistributedExecutor(DistributedExecutorBase):
    """Ray-based distributed executor.

    Each rank runs inside a ``RayWorkerWrapper`` Ray actor that is pinned to
    one bundle of ``parallel_config.placement_group``. Two modes exist:

    * Non-SPMD: a local ``driver_worker`` in this process is rank 0 and the
      remote actors in ``self.workers`` are ranks 1..N-1.
    * SPMD (``VLLM_USE_RAY_SPMD_WORKER=1``, which requires
      ``VLLM_USE_RAY_COMPILED_DAG=1``): every rank is a remote actor and each
      step is dispatched through a compiled Ray DAG.
    """

    # These env vars are worker-specific, therefore are NOT copied
    # from the driver to the workers
    WORKER_SPECIFIC_ENV_VARS = {
        "VLLM_HOST_IP",
        "VLLM_HOST_PORT",
        "LOCAL_RANK",
        "CUDA_VISIBLE_DEVICES",
    }

    # These non-vLLM env vars are copied from the driver to workers
    ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}

    uses_ray: bool = True

    def _init_executor(self) -> None:
        """Select the execution mode from env vars and start the Ray workers.

        Under V1 this forces SPMD workers and the compiled DAG by writing the
        corresponding env vars before reading them back through ``envs``
        (presumably ``envs`` reads ``os.environ`` lazily, so the values
        written here are the ones observed below).
        """
        self.forward_dag: ray.dag.CompiledDAG | None = None
        if envs.VLLM_USE_V1:
            # V1 uses SPMD worker and compiled DAG
            os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
            os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

            # For TPU or XPU, avoid compiling NVIDIA's NCCL
            if current_platform.is_tpu() or current_platform.is_xpu():
                os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"

        # If the env var is set, it uses the Ray's compiled DAG API
        # which optimizes the control plane overhead.
        # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
        # Currently, this requires USE_RAY_SPMD_WORKER=True.
        self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
        # If the env var is set, then we do not distinguish between the
        # "driver worker" vs other workers. Also, the rank 0 worker will
        # be executed in a remote Ray worker. Currently this requires
        # USE_RAY_COMPILED_DAG=True.
        self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
        if self.use_ray_compiled_dag:
            assert self.use_ray_spmd_worker, (
                "VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_RAY_SPMD_WORKER=1"
            )
        if self.use_ray_spmd_worker:
            # TODO: Support SPMD worker for non-DAG Ray executor.
            assert self.use_ray_compiled_dag, (
                "VLLM_USE_RAY_SPMD_WORKER=1 requires VLLM_USE_RAY_COMPILED_DAG=1"
            )

        assert self.uses_ray
        initialize_ray_cluster(self.parallel_config)
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        # msgpack codecs used for the non-V1 compiled-DAG input/output.
        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
        self.output_decoder = msgspec.msgpack.Decoder(list[SamplerOutput] | None)

        self.use_v1 = envs.VLLM_USE_V1

        # Per-PP-stage locks, created lazily in _driver_execute_model_async.
        self.pp_locks: list[asyncio.Lock] | None = None
        if not self.use_ray_compiled_dag:
            self.driver_exec_method = make_async(self.driver_worker.execute_method)

    def shutdown(self) -> None:
        """Tear down the compiled DAG, if one exists, and kill the workers.

        Worker actors are only killed when a compiled DAG was created.
        ``forward_dag`` is reset to None, so repeated calls are no-ops.
        """
        if logger:
            # Somehow logger can be None here.
            logger.info(
                "Shutting down Ray distributed executor. If you see error log "
                "from logging.cc regarding SIGTERM received, please ignore "
                "because this is the expected termination process in Ray."
            )
        # hasattr guard: __del__ may run on a partially-initialized executor.
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            import ray

            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

    def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]:
        """Add an ``nsight`` profiling entry to the workers' Ray runtime env.

        ``ray_remote_kwargs["runtime_env"]`` is created if missing and updated
        in place; the same dict is returned.
        """
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update(
            {
                "nsight": {
                    "t": "cuda,cudnn,cublas",
                    "o": "'worker_process_%p'",
                    "cuda-graph-trace": "node",
                }
            }
        )

        return ray_remote_kwargs

    # child class could overwrite this to return actual env vars.
    def _get_env_vars_to_be_updated(self):
        """Return the per-worker env-var dicts computed in _init_workers_ray."""
        return self._env_vars_for_all_workers

    def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs):
        """Create, re-rank, configure and initialize the Ray worker actors.

        Steps:
        1. Pick one placement-group bundle per rank (from
           ``VLLM_RAY_BUNDLE_INDICES``, or the first ``world_size`` bundles
           that carry the platform's device resource).
        2. Launch one ``RayWorkerWrapper`` actor per bundle.
        3. In non-SPMD mode, claim one actor on the driver's node as
           ``driver_dummy_worker`` (a resource holder) and create the local
           rank-0 ``driver_worker``.
        4. Re-rank the remaining actors so that workers on the driver node
           come first and workers on the same node are contiguous.
        5. Propagate device-visibility and copied env vars, then run
           ``init_worker``, ``init_device`` and ``load_model`` on all workers.
        6. Build the PP/TP worker groupings used at execution time.

        Args:
            placement_group: Placement group whose bundles host the workers.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote`` for
                every worker actor.

        Raises:
            ValueError: In non-SPMD mode, if no worker landed on the driver's
                node.
            RuntimeError: If the number of distinct node ids does not match
                the number of distinct IP addresses.
        """
        num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: RayWorkerWrapper | None = None
        # The remaining workers are the actual ray actors.
        self.workers: list[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: list[list[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs
            )

        logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)

        # Create the workers.
        bundle_indices: list[int]
        if envs.VLLM_RAY_BUNDLE_INDICES:
            # Use the bundle indices specified by the user.
            bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
            assert len(bundle_indices) == self.parallel_config.world_size, (
                "VLLM_RAY_BUNDLE_INDICES must have the same size"
                f" as the world size, but got {bundle_indices=} "
                f"and {self.parallel_config.world_size=}"
            )
            assert len(set(bundle_indices)) == len(bundle_indices), (
                "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
                f" but got {bundle_indices=}"
            )
        else:
            # use the first N bundles that have GPU resources.
            bundle_indices = [
                bundle_id
                for bundle_id, bundle in enumerate(placement_group.bundle_specs)
                if bundle.get(current_platform.ray_device_key, 0)
            ]
            bundle_indices = bundle_indices[: self.parallel_config.world_size]

        # Device resource request shared by every worker actor. GPU-like
        # devices (NV+AMD GPUs, and Intel XPUs) are requested via num_gpus;
        # other devices via a custom Ray resource. Workers reserve no CPUs.
        if current_platform.ray_device_key == "GPU":
            device_resource_kwargs: dict[str, Any] = {"num_gpus": num_gpus}
        else:
            device_resource_kwargs = {
                "num_gpus": 0,
                "resources": {current_platform.ray_device_key: num_gpus},
            }

        worker_metadata: list[RayWorkerMetaData] = []
        driver_ip = get_ip()
        for rank, bundle_id in enumerate(bundle_indices):
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            worker = ray.remote(
                num_cpus=0,
                scheduling_strategy=scheduling_strategy,
                **device_resource_kwargs,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(  # type: ignore[attr-defined]
                vllm_config=self.vllm_config, rpc_rank=rank
            )
            worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))

        worker_ips = ray.get(
            [
                each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
                for each in worker_metadata
            ]
        )

        for each, ip in zip(worker_metadata, worker_ips):
            each.ip = ip

        if not self.use_ray_spmd_worker:
            for i, each in enumerate(worker_metadata):
                # find and remove the dummy worker from the list
                if self.driver_dummy_worker is None and each.ip == driver_ip:
                    # If the worker is on the same node as the driver, we use it
                    # as the resource holder for the driver process.
                    self.driver_dummy_worker = each.worker
                    self.driver_worker = RayWorkerWrapper(
                        vllm_config=self.vllm_config, rpc_rank=0
                    )
                    worker_metadata.pop(i)
                    break

        logger.debug("workers: %s", worker_metadata)
        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. "
                f"Driver IP: {driver_ip}, worker IPs: {worker_ips}. "
                "Consider adjusting the Ray placement group or running "
                "the driver on a GPU node."
            )

        # Number of workers per node IP (includes the driver dummy worker).
        ip_counts: dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the worker is on a node with smaller IP address, it
                should be placed first.
            """
            ip = item.ip
            return (0 if ip == driver_ip else 1, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        sorted_worker_metadata = sorted(
            worker_metadata, key=sort_by_driver_then_worker_ip
        )
        # In non-SPMD mode rank 0 belongs to the local driver_worker.
        start_rank = 0 if self.use_ray_spmd_worker else 1
        for i, item in enumerate(sorted_worker_metadata):
            item.adjusted_rank = i + start_rank
        self.workers = [item.worker for item in sorted_worker_metadata]
        rerank_mapping = {
            item.created_rank: item.adjusted_rank for item in sorted_worker_metadata
        }
        self._run_workers("adjust_rank", rerank_mapping)

        # Get the set of GPU IDs used on each node. The driver dummy worker
        # (when present) is queried first so that index 0 of the result
        # corresponds to rank 0. All queries are issued before a single
        # ray.get, which returns results in request order.
        workers_to_query = [
            worker
            for worker in [self.driver_dummy_worker] + self.workers
            # driver_dummy_worker can be None when using ray spmd worker.
            if worker is not None
        ]
        worker_node_and_gpu_ids = ray.get(
            [
                worker.get_node_and_gpu_ids.remote()  # type: ignore[attr-defined]
                for worker in workers_to_query
            ]
        )

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            node_gpus[node_id].extend(int(x) for x in gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP`"
                " environment variable, make sure it is unique for"
                " each node."
            )

        # Set environment variables for the driver and workers: every worker
        # sees all device ids that are in use on its node.
        all_args_to_update_environment_variables = [
            {
                current_platform.device_control_env_var: ",".join(
                    map(str, node_gpus[node_id])
                ),
            }
            for (node_id, _) in worker_node_and_gpu_ids
        ]

        # Environment variables to copy from driver to workers
        env_vars_to_copy = get_env_vars_to_copy(
            exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
            additional_vars=set(current_platform.additional_env_vars).union(
                self.ADDITIONAL_ENV_VARS
            ),
            destination="workers",
        )

        # Copy existing env vars to each worker's args
        for args in all_args_to_update_environment_variables:
            # TODO: refactor platform-specific env vars
            for name in env_vars_to_copy:
                if name in os.environ:
                    args[name] = os.environ[name]

        self._env_vars_for_all_workers = all_args_to_update_environment_variables

        self._run_workers(
            "update_environment_variables", self._get_env_vars_to_be_updated()
        )

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port()
        )

        # Initialize the actual workers inside worker wrapper.
        all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            local_rank = node_workers[node_id].index(rank)
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
            all_kwargs.append(kwargs)
        self._run_workers("init_worker", all_kwargs)

        self._run_workers("init_device")
        self._run_workers(
            "load_model",
            max_concurrent_workers=self.parallel_config.max_parallel_loading_workers,
        )

        if self.use_ray_spmd_worker:
            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                self.pp_tp_workers.append([])
                for tp_rank in range(self.parallel_config.tensor_parallel_size):
                    # PP=2, TP=4
                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                    rank = (
                        pp_rank * self.parallel_config.tensor_parallel_size
                    ) + tp_rank
                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                    assert pp_rank < len(self.pp_tp_workers)
                    self.pp_tp_workers[pp_rank].append(self.workers[rank])

        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: list[RayWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: list[RayWorkerWrapper] = []

        # Enforce rank order for correct rank to return final output.
        for index, worker in enumerate(self.workers):
            # The driver worker is rank 0 and not in self.workers.
            rank = index + 1
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(worker)
            else:
                self.non_driver_workers.append(worker)

    def _driver_execute_model(
        self, execute_model_req: ExecuteModelRequest | None
    ) -> list[SamplerOutput] | None:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
        )
        return self.driver_worker.execute_method("execute_model", execute_model_req)

    def _serialize_dag_input(self, execute_model_req):
        """Prepare a request for the compiled DAG.

        V1 workers (``execute_model_ray``) take the request object as-is;
        V0 workers (``execute_model_spmd``) take msgpack-encoded bytes.
        """
        if self.use_v1:
            return execute_model_req
        return self.input_encoder.encode(execute_model_req)

    def _deserialize_dag_output(self, output):
        """Inverse of _serialize_dag_input for a single DAG output."""
        if self.use_v1:
            return output
        return self.output_decoder.decode(output)

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> list[SamplerOutput]:
        """Execute one step synchronously.

        Non-SPMD mode defers to the base class. SPMD mode lazily compiles the
        Ray DAG on first use and returns the output of the first worker of
        the last PP stage.
        """
        if not self.use_ray_spmd_worker:
            return super().execute_model(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        serialized_data = self._serialize_dag_input(execute_model_req)
        outputs = ray.get(self.forward_dag.execute(serialized_data))
        return self._deserialize_dag_output(outputs[0])

    def _run_workers(
        self,
        method: str | Callable,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: int | None = None,
        **kwargs,
    ) -> Any:
        """Run ``method`` on the workers and collect the results.

        Args:
            method: Name of the worker method to run, or a callable that is
                cloudpickled and shipped to the workers.
            async_run_tensor_parallel_workers_only: If True, the method is
                run only on the remote non-driver TP workers
                (``self.non_driver_workers``), not on the driver worker, and
                the list of Ray futures is returned without waiting for them.
                Not supported in SPMD mode.
            max_concurrent_workers: Not supported; any truthy value raises
                NotImplementedError.
            *args, **kwargs: Passed unchanged to every worker (all workers
                share the same args/kwargs).

        Returns:
            The driver worker's result (non-SPMD mode only) followed by the
            remote workers' results in ``self.workers`` order, or a list of
            futures when ``async_run_tensor_parallel_workers_only`` is True.
        """
        sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
        del method

        if self.use_ray_spmd_worker:
            assert not async_run_tensor_parallel_workers_only, (
                "async_run_tensor_parallel_workers_only is not supported for spmd mode."
            )

        if max_concurrent_workers:
            raise NotImplementedError("max_concurrent_workers is not supported yet.")

        # Start the ray workers first.
        ray_workers = self.workers
        if async_run_tensor_parallel_workers_only:
            ray_workers = self.non_driver_workers
        ray_worker_outputs = [
            worker.execute_method.remote(  # type: ignore[attr-defined]
                sent_method, *args, **kwargs
            )
            for worker in ray_workers
        ]

        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_worker_output = []
        # In SPMD mode, the driver worker is the same as any other worker,
        # so we only explicitly execute on the driver worker if using a
        # non-SPMD worker class.
        if not self.use_ray_spmd_worker:
            # Start the driver worker after all the ray workers.
            driver_worker_output = [
                self.driver_worker.execute_method(sent_method, *args, **kwargs)
            ]

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return driver_worker_output + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only=True to complete."""
        ray.get(parallel_worker_tasks)

    def _check_ray_cgraph_installation(self):
        """Check the prerequisites for Ray Compiled Graph.

        Raises:
            ValueError: If Ray is older than 2.43.0, the compiled-graph
                module cannot be found, or cupy is missing while the channel
                type is 'nccl'.
        """
        import importlib.metadata

        from packaging import version

        required_version = version.parse("2.43.0")
        current_version = version.parse(importlib.metadata.version("ray"))
        if current_version < required_version:
            raise ValueError(
                f"Ray version {required_version} is "
                f"required, but found {current_version}"
            )

        import importlib.util

        cgraph_spec = importlib.util.find_spec("ray.experimental.compiled_dag_ref")
        if cgraph_spec is None:
            raise ValueError(
                "Ray Compiled Graph is not installed. "
                "Run `pip install ray[cgraph]` to install it."
            )

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl":
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
                "Run `pip install ray[cgraph]` and check cupy installation."
            )

    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the Ray DAG used for SPMD execution.

        The DAG feeds the input to every worker of PP stage 0, chains each
        TP rank of a stage into the same TP rank of the next stage, and
        collects the outputs of the last stage.

        Args:
            enable_asyncio: Forwarded to ``experimental_compile``; True for
                the DAG used by ``execute_model_async``.

        Raises:
            ValueError: If Ray Compiled Graph prerequisites are missing or
                ``VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE`` is invalid.
        """
        assert self.parallel_config.use_ray
        self._check_ray_cgraph_installation()
        # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
        # (it is 10 seconds by default). This is a Ray environment variable to
        # control the timeout of getting result from a compiled graph execution,
        # i.e., the distributed execution that includes model forward runs and
        # intermediate tensor communications, in the case of vllm.
        # Note: we should set this env var before importing
        # ray.dag, otherwise it will not take effect.
        os.environ.setdefault("RAY_CGRAPH_get_timeout", "300")  # noqa: SIM112
        from ray.dag import InputNode, MultiOutputNode

        logger.info(
            "RAY_CGRAPH_get_timeout is set to %s",
            os.environ["RAY_CGRAPH_get_timeout"],  # noqa: SIM112
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE,
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )

        channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
        if channel_type not in ("auto", "nccl", "shm"):
            raise ValueError(
                "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
                f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'."
            )

        last_pp_rank = len(self.pp_tp_workers) - 1
        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            #
            # For V0:
            # ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput   # noqa: E501
            #
            # For V1:
            # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput   # noqa: E501

            # All workers in the first TP group will take in the
            # ExecuteModelRequest as input.
            outputs = [input_data for _ in self.pp_tp_workers[0]]
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                if self.use_v1:
                    outputs = [
                        worker.execute_model_ray.bind(  # type: ignore[attr-defined]
                            outputs[i]
                        )
                        for i, worker in enumerate(tp_group)
                    ]
                else:
                    outputs = [
                        worker.execute_model_spmd.bind(  # type: ignore[attr-defined]
                            outputs[i]
                        )
                        for i, worker in enumerate(tp_group)
                    ]

                if pp_rank < last_pp_rank and channel_type != "shm":
                    # Specify how intermediate tensors should be passed
                    # between pp stages, no need to specify for the last
                    # pp stage or when using shared memory.
                    outputs = [
                        output.with_tensor_transport(transport=channel_type)
                        for output in outputs
                    ]

            forward_dag = MultiOutputNode(outputs)

        if envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
            from ray.experimental.channel.accelerator_context import (
                register_accelerator_context,
            )

            from vllm.distributed.device_communicators.ray_communicator import (
                RayPPCommunicator,
            )

            register_accelerator_context(
                torch_module_name="cuda", communicator_cls=RayPPCommunicator
            )
            logger.info(
                "Using RayPPCommunicator "
                "(which wraps vLLM _PP GroupCoordinator) "
                "for Ray Compiled Graph communication."
            )
        else:
            logger.info(
                "Using Ray's NCCL communicator for Ray Compiled Graph communication."
            )

        return forward_dag.experimental_compile(
            enable_asyncio=enable_asyncio,
            _overlap_gpu_communication=envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )

    def __del__(self):
        self.shutdown()

    async def execute_model_async(
        self, execute_model_req: ExecuteModelRequest
    ) -> list[SamplerOutput]:
        """Asynchronous counterpart of execute_model.

        Uses the same input/output (de)serialization as execute_model, so V1
        requests are passed to the DAG unencoded, exactly as in the sync
        path.
        """
        if not self.use_ray_spmd_worker:
            return await super().execute_model_async(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

        serialized_data = self._serialize_dag_input(execute_model_req)
        dag_future = await self.forward_dag.execute_async(serialized_data)
        output = await dag_future[0]
        return self._deserialize_dag_output(output)

    async def _driver_execute_model_async(
        self, execute_model_req: ExecuteModelRequest | None = None
    ) -> list[SamplerOutput]:
        """Run one step through the non-SPMD driver path asynchronously.

        Without remote TP-group drivers, only the local driver runs.
        Otherwise the local driver (PP stage 0) and each remote TP-group
        driver (PP stages 1, 2, ...) run concurrently, each holding its
        stage's lock, and the last stage's result is returned.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
        )
        if not self.tp_driver_workers:
            return await self.driver_exec_method("execute_model", execute_model_req)
        if self.pp_locks is None:
            # This locks each pipeline parallel stage so multiple virtual
            # engines can't execute on the same stage at the same time
            # We create the locks here to avoid creating them in the constructor
            # which uses a different asyncio loop.
            self.pp_locks = [
                asyncio.Lock()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]

        tasks = [
            asyncio.create_task(
                _run_task_with_lock(
                    self.driver_exec_method,
                    self.pp_locks[0],
                    "execute_model",
                    execute_model_req,
                )
            )
        ]
        for pp_rank, driver_worker in enumerate(self.tp_driver_workers, start=1):
            tasks.append(
                asyncio.create_task(
                    _run_task_with_lock(
                        driver_worker.execute_method.remote,  # type: ignore[attr-defined]
                        self.pp_locks[pp_rank],
                        "execute_model",
                        execute_model_req,
                    )
                )
            )

        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]

    async def _start_worker_execution_loop(self):
        """Start the execution loop on every non-driver worker and await all."""
        assert not self.use_ray_spmd_worker, (
            "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1"
        )
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")  # type: ignore[attr-defined]
            for worker in self.non_driver_workers
        ]
        return await asyncio.gather(*coros)

    def check_health(self) -> None:
        """No-op: the Ray workers are currently assumed to be healthy."""
        # Assume that the Ray workers are healthy.
        # TODO: check the health of the Ray workers
        return
750
751
752
753
754
755


async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs):
    """Utility function to run async task in a lock"""
    async with lock:
        return await task(*args, **kwargs)