# ray_distributed_executor.py (30.5 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import os
6
from collections import defaultdict
7
from collections.abc import Callable
8
from dataclasses import dataclass
9
from typing import TYPE_CHECKING, Any
10

11
import cloudpickle
12
13
import msgspec

14
import vllm.envs as envs
15
from vllm.executor.executor_base import DistributedExecutorBase
16
from vllm.executor.msgspec_utils import encode_hook
17
from vllm.executor.ray_utils import RayWorkerWrapper, initialize_ray_cluster, ray
18
from vllm.logger import init_logger
19
from vllm.platforms import current_platform
20
from vllm.ray.ray_env import get_env_vars_to_copy
21
from vllm.sequence import ExecuteModelRequest
22
23
24
25
26
27
28
from vllm.utils import (
    _run_task_with_lock,
    get_distributed_init_method,
    get_ip,
    get_open_port,
    make_async,
)
29
from vllm.v1.outputs import SamplerOutput
30
31

if ray is None:
    # Ray is not installed: keep the name defined so annotations that
    # reference it still evaluate.
    ActorHandle = None
else:
    from ray.actor import ActorHandle
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    # Only used in type annotations; not imported at runtime.
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


@dataclass
class RayWorkerMetaData:
    """Bookkeeping record for one Ray worker actor.

    Ray may start worker actors in any order, so each record keeps the rank
    the worker was created with and, once every worker exists, the rank it
    is reassigned to.
    """

    # Handle to the remote worker actor.
    worker: ActorHandle
    # Rank passed to the worker at creation time.
    created_rank: int
    # Rank after re-sorting all workers; -1 until it is assigned.
    adjusted_rank: int = -1
    # IP address of the node the worker runs on; empty until looked up.
    ip: str = ""


class RayDistributedExecutor(DistributedExecutorBase):
    """Distributed executor that runs vLLM workers as Ray actors."""

    # Per-worker settings (host address/port, local rank, visible devices).
    # Their values differ from worker to worker, so they are never copied
    # from the driver onto the workers.
    WORKER_SPECIFIC_ENV_VARS = {
        "CUDA_VISIBLE_DEVICES",
        "LOCAL_RANK",
        "VLLM_HOST_IP",
        "VLLM_HOST_PORT",
    }

    # Env vars outside vLLM's own namespace that are nevertheless forwarded
    # from the driver to every worker.
    ADDITIONAL_ENV_VARS = {"HUGGING_FACE_HUB_TOKEN", "HF_TOKEN"}

    # Marks this executor as Ray-backed.
    uses_ray: bool = True

    def _init_executor(self) -> None:
        """Configure Ray execution modes, start the cluster and the workers."""
        self.forward_dag: ray.dag.CompiledDAG | None = None

        if envs.VLLM_USE_V1:
            # V1 always runs SPMD workers behind a compiled DAG; force both
            # switches on before they are read below.
            os.environ.update(
                {
                    "VLLM_USE_RAY_SPMD_WORKER": "1",
                    "VLLM_USE_RAY_COMPILED_DAG": "1",
                }
            )
            if current_platform.is_tpu() or current_platform.is_xpu():
                # TPU/XPU: use shared-memory channels so NVIDIA's NCCL is
                # not compiled.
                os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"

        # Compiled-DAG mode uses Ray's compiled graph API to cut control-plane
        # overhead. SPMD mode drops the separate "driver worker" and runs
        # rank 0 in a remote Ray actor as well. Each mode currently requires
        # the other.
        self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
        self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
        assert self.use_ray_spmd_worker or not self.use_ray_compiled_dag, (
            "VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_RAY_SPMD_WORKER=1"
        )
        # TODO: Support SPMD worker for non-DAG Ray executor.
        assert self.use_ray_compiled_dag or not self.use_ray_spmd_worker, (
            "VLLM_USE_RAY_SPMD_WORKER=1 requires VLLM_USE_RAY_COMPILED_DAG=1"
        )

        assert self.uses_ray
        initialize_ray_cluster(self.parallel_config)
        placement_group = self.parallel_config.placement_group

        # Opt out of Ray usage-stats collection unless explicitly enabled.
        if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Launch the worker actors inside the placement group.
        self._init_workers_ray(placement_group)

        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
        self.output_decoder = msgspec.msgpack.Decoder(list[SamplerOutput] | None)

        self.use_v1 = envs.VLLM_USE_V1

        # Pipeline-stage locks are built lazily by the async driver path.
        self.pp_locks: list[asyncio.Lock] | None = None
        if not self.use_ray_compiled_dag:
            self.driver_exec_method = make_async(self.driver_worker.execute_method)

    def shutdown(self) -> None:
        """Tear down the compiled DAG, if one was built, and kill the workers."""
        # The module-level logger has been observed to be None at this point
        # (presumably late in interpreter teardown), so guard the call.
        if logger:
            logger.info(
                "Shutting down Ray distributed executor. If you see error log "
                "from logging.cc regarding SIGTERM received, please ignore "
                "because this is the expected termination process in Ray."
            )

        # forward_dag may be missing entirely if initialization failed early.
        forward_dag = getattr(self, "forward_dag", None)
        if forward_dag is None:
            return

        forward_dag.teardown()
        import ray

        for worker in self.workers:
            ray.kill(worker)
        self.forward_dag = None

    def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]:
142
143
144
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
145
146
147
148
149
150
151
        runtime_env.update(
            {
                "nsight": {
                    "t": "cuda,cudnn,cublas",
                    "o": "'worker_process_%p'",
                    "cuda-graph-trace": "node",
                }
152
            }
153
        )
154
155
156

        return ray_remote_kwargs

157
158
159
160
    # child class could overwrite this to return actual env vars.
    def _get_env_vars_to_be_updated(self):
        return self._env_vars_for_all_workers

161
    def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs):
        """Create the Ray worker actors and bring them up.

        Overview:
          1. Choose one placement-group bundle per rank, either from
             ``VLLM_RAY_BUNDLE_INDICES`` or the first ``world_size`` bundles
             with a non-zero ``current_platform.ray_device_key`` resource,
             and start a ``RayWorkerWrapper`` actor on each.
          2. In non-SPMD mode, take the first actor on the driver's node as a
             placeholder that holds resources for the in-process driver
             worker (rank 0).
          3. Re-rank the remaining actors: driver-node workers first, then
             nodes hosting fewer workers, then by IP, so workers on the same
             node end up next to each other.
          4. Send every worker its env vars (device visibility plus the driver
             env vars selected by ``get_env_vars_to_copy``), then run
             ``init_worker``, ``init_device`` and ``load_model``.
          5. Build ``pp_tp_workers`` (SPMD mode) and the
             ``tp_driver_workers`` / ``non_driver_workers`` lists.

        Args:
            placement_group: Placement group whose bundles host the workers.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote`` for
                every worker actor.

        Raises:
            ValueError: In non-SPMD mode, if no worker was placed on the
                driver's node.
            RuntimeError: If the number of distinct node IDs differs from the
                number of distinct IPs.
        """
        num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: RayWorkerWrapper | None = None
        # The remaining workers are the actual ray actors.
        self.workers: list[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: list[list[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs
            )

        logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)

        # Pick the placement-group bundle that hosts each rank.
        bundle_indices: list[int]
        if envs.VLLM_RAY_BUNDLE_INDICES:
            # Use the bundle indices specified by the user.
            bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
            assert len(bundle_indices) == self.parallel_config.world_size, (
                "VLLM_RAY_BUNDLE_INDICES must have the same size"
                f" as the world size, but got {bundle_indices=} "
                f"and {self.parallel_config.world_size=}"
            )
            assert len(set(bundle_indices)) == len(bundle_indices), (
                "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
                f" but got {bundle_indices=}"
            )
        else:
            # Use the first N bundles that have device resources.
            bundle_indices = [
                bundle_id
                for bundle_id, bundle in enumerate(placement_group.bundle_specs)
                if bundle.get(current_platform.ray_device_key, 0)
            ][: self.parallel_config.world_size]

        # Start one worker actor per selected bundle.
        worker_metadata: list[RayWorkerMetaData] = []
        driver_ip = get_ip()
        for rank, bundle_id in enumerate(bundle_indices):
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            if current_platform.ray_device_key == "GPU":
                # NV+AMD GPUs, and Intel XPUs
                device_options: dict[str, Any] = {"num_gpus": num_gpus}
            else:
                # Other accelerators are requested as a custom Ray resource.
                device_options = {
                    "num_gpus": 0,
                    "resources": {current_platform.ray_device_key: num_gpus},
                }
            worker = ray.remote(
                num_cpus=0,
                scheduling_strategy=scheduling_strategy,
                **device_options,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(  # type: ignore[attr-defined]
                vllm_config=self.vllm_config, rpc_rank=rank
            )
            worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))

        # Look up every worker's node IP in a single batched ray.get.
        worker_ips = ray.get(
            [
                each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
                for each in worker_metadata
            ]
        )
        for each, ip in zip(worker_metadata, worker_ips):
            each.ip = ip

        if not self.use_ray_spmd_worker:
            # Find and remove the dummy worker from the list: the first worker
            # on the driver's node becomes the resource holder for the
            # in-process driver worker.
            for i, each in enumerate(worker_metadata):
                if each.ip == driver_ip:
                    self.driver_dummy_worker = each.worker
                    self.driver_worker = RayWorkerWrapper(
                        vllm_config=self.vllm_config, rpc_rank=0
                    )
                    worker_metadata.pop(i)
                    break

        logger.debug("workers: %s", worker_metadata)
        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. "
                f"Driver IP: {driver_ip}, worker IPs: {worker_ips}. "
                "Consider adjusting the Ray placement group or running "
                "the driver on a GPU node."
            )

        # Number of workers (including the dummy worker, if any) per node IP.
        ip_counts: dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the worker is on a node with smaller IP address, it
                should be placed first.
            """
            ip = item.ip
            return (0 if ip == driver_ip else 1, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        sorted_worker_metadata = sorted(
            worker_metadata, key=sort_by_driver_then_worker_ip
        )
        # In non-SPMD mode rank 0 belongs to the in-process driver worker, so
        # the remote workers are numbered from 1.
        start_rank = 0 if self.use_ray_spmd_worker else 1
        for i, item in enumerate(sorted_worker_metadata):
            item.adjusted_rank = i + start_rank
        self.workers = [item.worker for item in sorted_worker_metadata]
        rerank_mapping = {
            item.created_rank: item.adjusted_rank for item in sorted_worker_metadata
        }
        self._run_workers("adjust_rank", rerank_mapping)

        # Get the node ID and GPU IDs of every worker in one batched ray.get
        # instead of one blocking round trip per worker. List position equals
        # global rank: the driver dummy worker (non-SPMD only) is rank 0,
        # followed by self.workers in rank order.
        worker_node_and_gpu_ids = ray.get(
            [
                worker.get_node_and_gpu_ids.remote()  # type: ignore[attr-defined]
                for worker in [self.driver_dummy_worker, *self.workers]
                # driver_dummy_worker can be None when using ray spmd worker.
                if worker is not None
            ]
        )

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            node_gpus[node_id].extend(int(x) for x in gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP`"
                " environment variable, make sure it is unique for"
                " each node."
            )

        # Set environment variables for the driver and workers: every worker
        # sees all device IDs in use on its own node.
        all_args_to_update_environment_variables = [
            {
                current_platform.device_control_env_var: ",".join(
                    map(str, node_gpus[node_id])
                ),
            }
            for (node_id, _) in worker_node_and_gpu_ids
        ]

        # Environment variables to copy from driver to workers
        env_vars_to_copy = get_env_vars_to_copy(
            exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
            additional_vars=set(current_platform.additional_env_vars).union(
                self.ADDITIONAL_ENV_VARS
            ),
            destination="workers",
        )

        # Copy existing env vars to each worker's args
        for args in all_args_to_update_environment_variables:
            # TODO: refactor platform-specific env vars
            for name in env_vars_to_copy:
                if name in os.environ:
                    args[name] = os.environ[name]

        self._env_vars_for_all_workers = all_args_to_update_environment_variables

        self._run_workers(
            "update_environment_variables", self._get_env_vars_to_be_updated()
        )

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port()
        )

        # Initialize the actual workers inside worker wrapper.
        all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            local_rank = node_workers[node_id].index(rank)
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
            all_kwargs.append(kwargs)
        self._run_workers("init_worker", all_kwargs)

        self._run_workers("init_device")
        self._run_workers(
            "load_model",
            max_concurrent_workers=self.parallel_config.max_parallel_loading_workers,
        )

        if self.use_ray_spmd_worker:
            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                self.pp_tp_workers.append([])
                for tp_rank in range(self.parallel_config.tensor_parallel_size):
                    # PP=2, TP=4
                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                    rank = (
                        pp_rank * self.parallel_config.tensor_parallel_size
                    ) + tp_rank
                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                    assert pp_rank < len(self.pp_tp_workers)
                    self.pp_tp_workers[pp_rank].append(self.workers[rank])

        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: list[RayWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: list[RayWorkerWrapper] = []

        # Enforce rank order for correct rank to return final output.
        # NOTE: the `index + 1` offset assumes the non-SPMD layout, where the
        # driver worker is rank 0 and not in self.workers. In this file these
        # two lists are only read by non-SPMD code paths.
        for index, worker in enumerate(self.workers):
            rank = index + 1
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(worker)
            else:
                self.non_driver_workers.append(worker)

    def _driver_execute_model(
        self, execute_model_req: ExecuteModelRequest | None
    ) -> list[SamplerOutput] | None:
        """Run ``execute_model`` on the in-process driver worker.

        Only valid in non-SPMD mode, where a local driver worker exists.
        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
        )
        driver = self.driver_worker
        return driver.execute_method("execute_model", execute_model_req)

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> list[SamplerOutput]:
        """Execute one model step.

        Non-SPMD mode defers to the base-class implementation. SPMD mode
        sends the request through the compiled Ray DAG (compiled on first
        use). V1 requests and outputs pass through unchanged; otherwise the
        request is msgpack-encoded and the first DAG output decoded.
        """
        if self.use_ray_spmd_worker:
            if self.forward_dag is None:
                self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

            payload = (
                execute_model_req
                if self.use_v1
                else self.input_encoder.encode(execute_model_req)
            )
            first_output = ray.get(self.forward_dag.execute(payload))[0]
            if self.use_v1:
                return first_output
            return self.output_decoder.decode(first_output)

        return super().execute_model(execute_model_req)

    def _run_workers(
        self,
        method: str | Callable,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: int | None = None,
        **kwargs,
    ) -> Any:
        """Run ``method`` on the workers and collect the results.

        ``method`` is either the name of a worker method or a callable. A
        callable is serialized with cloudpickle before it is sent. Every
        worker receives the same ``args``/``kwargs``.

        Args:
            method: Worker method name, or a callable to run on each worker.
            async_run_tensor_parallel_workers_only: If True, run only on
                ``self.non_driver_workers`` and return the Ray futures
                without waiting for them. Not supported in SPMD mode.
            max_concurrent_workers: Not supported; any truthy value raises
                NotImplementedError.

        Returns:
            The list of futures when ``async_run_tensor_parallel_workers_only``
            is set. Otherwise, the driver worker's result (non-SPMD mode only)
            followed by the result of each remote worker.
        """
        if isinstance(method, str):
            sent_method = method
        else:
            sent_method = cloudpickle.dumps(method)
        del method

        if self.use_ray_spmd_worker:
            assert not async_run_tensor_parallel_workers_only, (
                "async_run_tensor_parallel_workers_only is not supported for spmd mode."
            )

        if max_concurrent_workers:
            raise NotImplementedError("max_concurrent_workers is not supported yet.")

        # Launch the remote calls first so they run while the driver works.
        targets = (
            self.non_driver_workers
            if async_run_tensor_parallel_workers_only
            else self.workers
        )
        ray_worker_outputs = [
            worker.execute_method.remote(  # type: ignore[attr-defined]
                sent_method, *args, **kwargs
            )
            for worker in targets
        ]

        if async_run_tensor_parallel_workers_only:
            # Hand the futures back to the caller without blocking.
            return ray_worker_outputs

        # In SPMD mode rank 0 is just another remote worker; only the
        # non-SPMD layout has an in-process driver worker to call directly.
        driver_worker_output = (
            []
            if self.use_ray_spmd_worker
            else [self.driver_worker.execute_method(sent_method, *args, **kwargs)]
        )

        # Block on the remote results.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return driver_worker_output + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Block until the given Ray futures have finished.

        ``parallel_worker_tasks`` is the list of futures returned by
        ``_run_workers(..., async_run_tensor_parallel_workers_only=True)``.
        The results are discarded.
        """
        ray.get(parallel_worker_tasks)

    def _check_ray_cgraph_installation(self):
        """Verify that the installed Ray can run Compiled Graphs.

        Raises:
            ValueError: If Ray is older than 2.43.0, if
                ``ray.experimental.compiled_dag_ref`` cannot be found, or if
                cupy is missing while ``VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE``
                is ``"nccl"``.
        """
        import importlib.metadata
        import importlib.util

        from packaging import version

        required_version = version.parse("2.43.0")
        current_version = version.parse(importlib.metadata.version("ray"))
        if not current_version >= required_version:
            raise ValueError(
                f"Ray version {required_version} is "
                f"required, but found {current_version}"
            )

        if importlib.util.find_spec("ray.experimental.compiled_dag_ref") is None:
            raise ValueError(
                "Ray Compiled Graph is not installed. "
                "Run `pip install ray[cgraph]` to install it."
            )

        cupy_missing = importlib.util.find_spec("cupy") is None
        if cupy_missing and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl":
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
                "Run `pip install ray[cgraph]` and check cupy installation."
            )

    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the Ray Compiled Graph that runs one model step.

        Every worker of the first PP stage receives the DAG input. Each TP
        worker of stage ``k`` feeds the worker at the same TP index in stage
        ``k + 1``, and the outputs of the last stage are collected with
        ``MultiOutputNode``. V1 binds ``execute_model_ray`` on each worker;
        otherwise ``execute_model_spmd`` is bound.

        Args:
            enable_asyncio: Forwarded to ``experimental_compile``. True when
                called from ``execute_model_async``, False from
                ``execute_model``.

        Returns:
            The compiled DAG.

        Raises:
            ValueError: If the Ray Compiled Graph prerequisites are missing,
                or ``VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE`` is not one of
                'auto', 'nccl', 'shm'.
        """
        assert self.parallel_config.use_ray
        self._check_ray_cgraph_installation()
        # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
        # (it is 10 seconds by default). This is a Ray environment variable to
        # control the timeout of getting result from a compiled graph execution,
        # i.e., the distributed execution that includes model forward runs and
        # intermediate tensor communications, in the case of vllm.
        # Note: we should set this env var before importing
        # ray.dag, otherwise it will not take effect.
        os.environ.setdefault("RAY_CGRAPH_get_timeout", "300")  # noqa: SIM112
        from ray.dag import InputNode, MultiOutputNode

        logger.info(
            "RAY_CGRAPH_get_timeout is set to %s",
            os.environ["RAY_CGRAPH_get_timeout"],  # noqa: SIM112
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE,
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )

        # Validate once, then use this validated value everywhere below.
        channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
        if channel_type not in ("auto", "nccl", "shm"):
            raise ValueError(
                "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
                f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'."
            )

        last_pp_rank = len(self.pp_tp_workers) - 1
        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            #
            # For V0:
            # ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput   # noqa: E501
            #
            # For V1:
            # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput   # noqa: E501

            # All workers in the first TP group will take in the
            # ExecuteModelRequest as input.
            outputs = [input_data for _ in self.pp_tp_workers[0]]
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                if self.use_v1:
                    outputs = [
                        worker.execute_model_ray.bind(  # type: ignore[attr-defined]
                            outputs[i]
                        )
                        for i, worker in enumerate(tp_group)
                    ]
                else:
                    outputs = [
                        worker.execute_model_spmd.bind(  # type: ignore[attr-defined]
                            outputs[i]
                        )
                        for i, worker in enumerate(tp_group)
                    ]

                if pp_rank < last_pp_rank and channel_type != "shm":
                    # Specify how intermediate tensors should be passed
                    # between pp stages. This is skipped for the last pp
                    # stage and when using shared memory.
                    outputs = [
                        output.with_tensor_transport(transport=channel_type)
                        for output in outputs
                    ]

            forward_dag = MultiOutputNode(outputs)

        if envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
            from ray.experimental.channel.accelerator_context import (
                register_accelerator_context,
            )

            from vllm.distributed.device_communicators.ray_communicator import (
                RayPPCommunicator,
            )

            register_accelerator_context(
                torch_module_name="cuda", communicator_cls=RayPPCommunicator
            )
            logger.info(
                "Using RayPPCommunicator "
                "(which wraps vLLM _PP GroupCoordinator) "
                "for Ray Compiled Graph communication."
            )
        else:
            logger.info(
                "Using Ray's NCCL communicator for Ray Compiled Graph communication."
            )

        return forward_dag.experimental_compile(
            enable_asyncio=enable_asyncio,
            _overlap_gpu_communication=envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )

    def __del__(self):
        # Best-effort cleanup on garbage collection. shutdown() looks up
        # `forward_dag` defensively, so this also runs after a failed init.
        self.shutdown()

    async def execute_model_async(
        self, execute_model_req: ExecuteModelRequest
    ) -> list[SamplerOutput]:
        """Async counterpart of ``execute_model``.

        Non-SPMD mode defers to the base class. SPMD mode sends the request
        through the compiled Ray DAG, compiled with asyncio support on first
        use. Serialization matches ``execute_model``: V1 requests and outputs
        pass through unchanged, and everything else is msgpack-encoded on the
        way in and decoded on the way out.
        """
        if not self.use_ray_spmd_worker:
            return await super().execute_model_async(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

        # Must mirror execute_model(): under V1 the DAG binds
        # execute_model_ray, which receives the request object itself rather
        # than msgpack bytes, and returns an output that is not decoded.
        if self.use_v1:
            serialized_data = execute_model_req
        else:
            serialized_data = self.input_encoder.encode(execute_model_req)
        dag_future = await self.forward_dag.execute_async(serialized_data)
        output = await dag_future[0]
        if self.use_v1:
            return output
        return self.output_decoder.decode(output)

    async def _driver_execute_model_async(
        self, execute_model_req: ExecuteModelRequest | None = None
    ) -> list[SamplerOutput]:
        """Run one model step across all pipeline stages (non-SPMD mode).

        Stage 0 runs on the in-process driver worker. Stage k (k >= 1) runs
        on ``self.tp_driver_workers[k - 1]``. Each stage is guarded by its
        own asyncio lock. Returns the result of the last stage.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
        )
        if not self.tp_driver_workers:
            # Only the driver's own stage exists; no locking needed.
            return await self.driver_exec_method("execute_model", execute_model_req)

        if self.pp_locks is None:
            # One lock per pipeline stage so multiple virtual engines can't
            # execute on the same stage at the same time. Created here rather
            # than in the constructor, which uses a different asyncio loop.
            self.pp_locks = [
                asyncio.Lock()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]

        stage_runners = [self.driver_exec_method]
        stage_runners.extend(
            worker.execute_method.remote  # type: ignore[attr-defined]
            for worker in self.tp_driver_workers
        )
        tasks = [
            asyncio.create_task(
                _run_task_with_lock(
                    run_stage,
                    self.pp_locks[pp_rank],
                    "execute_model",
                    execute_model_req,
                )
            )
            for pp_rank, run_stage in enumerate(stage_runners)
        ]

        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]

    async def _start_worker_execution_loop(self):
        """Invoke ``start_worker_execution_loop`` on every non-driver worker.

        Non-SPMD mode only. Awaits all the remote calls and returns their
        results as a list.
        """
        assert not self.use_ray_spmd_worker, (
            "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1"
        )
        loop_refs = [
            remote_worker.execute_method.remote(  # type: ignore[attr-defined]
                "start_worker_execution_loop"
            )
            for remote_worker in self.non_driver_workers
        ]
        return await asyncio.gather(*loop_refs)

    def check_health(self) -> None:
        """Health check placeholder; Ray workers are assumed healthy.

        TODO: check the health of the Ray workers.
        """
        return None