# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import cloudpickle
import msgspec

import vllm.envs as envs
from vllm.executor.executor_base import (
    DistributedExecutorBase)  # yapf: disable
from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
                                     ray)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
                        get_ip, get_open_port, make_async)

# `ray` is imported via vllm.executor.ray_utils; the Ray-specific names below
# are only importable when Ray is actually available.
if ray is not None:
    from ray.actor import ActorHandle
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
    ActorHandle = None

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)

@dataclass
class RayWorkerMetaData:
    """
    Bookkeeping record for one Ray worker actor.

    Ray may create worker actors in any order, so each record keeps the rank
    the actor was created with and, once all workers exist and have been
    re-sorted, the rank it is reassigned to.
    """
    # Handle of the remote worker actor.
    worker: ActorHandle
    # Rank the actor was created with (its position in the bundle list).
    created_rank: int
    # Rank assigned after re-sorting; -1 means "not assigned yet".
    adjusted_rank: int = -1
    # IP of the node hosting the actor; empty until it has been queried.
    ip: str = ""


class RayDistributedExecutor(DistributedExecutorBase):
    """Ray-based distributed executor.

    Creates one ``RayWorkerWrapper`` actor per selected bundle of the parallel
    config's placement group. Model execution is driven either through
    per-call RPCs (``_run_workers`` plus an in-process driver worker), or,
    when ``VLLM_USE_RAY_COMPILED_DAG``/``VLLM_USE_RAY_SPMD_WORKER`` are
    enabled, through a Ray Compiled Graph stored in ``forward_dag``.
    """

    # These env vars are worker-specific, therefore are NOT copied
    # from the driver to the workers
    WORKER_SPECIFIC_ENV_VARS = {
        "VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES"
    }

    # These non-vLLM env vars are copied from the driver to workers
    ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}

    uses_ray: bool = True

    def _init_executor(self) -> None:
        """Select the execution mode, start Ray and create all workers.

        Raises:
            AssertionError: if exactly one of VLLM_USE_RAY_COMPILED_DAG and
                VLLM_USE_RAY_SPMD_WORKER is enabled (they must be used
                together).
        """
        self.forward_dag: Optional[ray.dag.CompiledDAG] = None
        if envs.VLLM_USE_V1:
            # V1 uses SPMD worker and compiled DAG
            os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
            os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

            # For TPU or XPU, avoid compiling NVIDIA's NCCL
            if current_platform.is_tpu() or current_platform.is_xpu():
                os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"

        # If the env var is set, it uses the Ray's compiled DAG API
        # which optimizes the control plane overhead.
        # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
        # Currently, this requires USE_RAY_SPMD_WORKER=True.
        self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
        # If the env var is set, then we do not distinguish between the
        # "driver worker" vs other workers. Also, the rank 0 worker will
        # be executed in a remote Ray worker. Currently this requires
        # USE_RAY_COMPILED_DAG=True.
        self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
        if self.use_ray_compiled_dag:
            assert self.use_ray_spmd_worker, (
                "VLLM_USE_RAY_COMPILED_DAG=1 requires "
                "VLLM_USE_RAY_SPMD_WORKER=1")
        if self.use_ray_spmd_worker:
            # TODO: Support SPMD worker for non-DAG Ray executor.
            assert self.use_ray_compiled_dag, (
                "VLLM_USE_RAY_SPMD_WORKER=1 requires "
                "VLLM_USE_RAY_COMPILED_DAG=1")

        assert self.uses_ray
        initialize_ray_cluster(self.parallel_config)
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
        self.output_decoder = msgspec.msgpack.Decoder(
            Optional[List[SamplerOutput]])
        self.use_v1 = envs.VLLM_USE_V1

        self.pp_locks: Optional[List[asyncio.Lock]] = None
        if not self.use_ray_compiled_dag:
            self.driver_exec_method = make_async(
                self.driver_worker.execute_method)

    def shutdown(self) -> None:
        """Tear down the compiled graph, if any, and kill the worker actors.

        The teardown only runs while ``forward_dag`` is set, and
        ``forward_dag`` is reset to None afterwards. Repeated calls, or calls
        on an executor whose ``_init_executor`` never ran, therefore skip it
        (the log line is still emitted).
        """
        logger.info(
            "Shutting down Ray distributed executor. If you see error log "
            "from logging.cc regarding SIGTERM received, please ignore because "
            "this is the expected termination process in Ray.")
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            import ray
            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        """Add an nsight profiling ``runtime_env`` entry to the remote kwargs.

        Mutates and returns ``ray_remote_kwargs``.
        """
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    # child class could overwrite this to return actual env vars.
    def _get_env_vars_to_be_updated(self):
        """Return the per-worker env-var dicts sent to
        ``update_environment_variables``."""
        return self._env_vars_for_all_workers

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Create one Ray worker actor per selected bundle and initialize it.

        Steps:
          1. Choose bundles, either from VLLM_RAY_BUNDLE_INDICES or the first
             world_size bundles that have the platform's device resource.
          2. Spawn the actors.
          3. In non-SPMD mode, take one actor on the driver's node as the
             resource holder for the in-process driver worker.
          4. Re-rank the remaining actors so that driver-node workers come
             first and workers on the same node are contiguous.
          5. Push per-worker environment variables.
          6. Run ``init_worker``, ``init_device`` and ``load_model``.

        Args:
            placement_group: Placement group whose bundles host the workers.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote``
                for every worker actor.

        Raises:
            ValueError: in non-SPMD mode, if no actor landed on the driver's
                node.
            RuntimeError: if the number of distinct node ids differs from the
                number of distinct IP addresses.
        """
        num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: List[List[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)

        # Create the workers.
        bundle_indices: List[int]
        if envs.VLLM_RAY_BUNDLE_INDICES:
            # Use the bundle indices specified by the user.
            bundle_indices = list(
                map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
            assert len(bundle_indices) == self.parallel_config.world_size, (
                "VLLM_RAY_BUNDLE_INDICES must have the same size"
                f" as the world size, but got {bundle_indices=} "
                f"and {self.parallel_config.world_size=}")
            assert len(set(bundle_indices)) == len(bundle_indices), (
                "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
                f" but got {bundle_indices=}")
        else:
            # use the first N bundles that have GPU resources.
            bundle_indices = []
            for bundle_id, bundle in enumerate(placement_group.bundle_specs):
                if bundle.get(current_platform.ray_device_key, 0):
                    bundle_indices.append(bundle_id)
            bundle_indices = bundle_indices[:self.parallel_config.world_size]

        worker_metadata: List[RayWorkerMetaData] = []
        driver_ip = get_ip()
        for rank, bundle_id in enumerate(bundle_indices):
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            if current_platform.ray_device_key == "GPU":
                # NV+AMD GPUs, and Intel XPUs
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=num_gpus,
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
                                           rpc_rank=rank)
            else:
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=0,
                    resources={current_platform.ray_device_key: num_gpus},
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
                                           rpc_rank=rank)
            worker_metadata.append(
                RayWorkerMetaData(worker=worker, created_rank=rank))

        worker_ips = ray.get([
            each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
            for each in worker_metadata
        ])

        for each, ip in zip(worker_metadata, worker_ips):
            each.ip = ip

        if not self.use_ray_spmd_worker:
            for i, each in enumerate(worker_metadata):
                # find and remove the dummy worker from the list
                worker = each.worker
                worker_ip = each.ip
                if self.driver_dummy_worker is None and worker_ip == driver_ip:
                    # If the worker is on the same node as the driver, we use it
                    # as the resource holder for the driver process.
                    self.driver_dummy_worker = worker
                    self.driver_worker = RayWorkerWrapper(
                        vllm_config=self.vllm_config, rpc_rank=0)
                    # Popping is safe here: we break out of the loop at once.
                    worker_metadata.pop(i)
                    break

        logger.debug("workers: %s", worker_metadata)
        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. "
                f"Driver IP: {driver_ip}, worker IPs: {worker_ips}. "
                "Consider adjusting the Ray placement group or running "
                "the driver on a GPU node.")

        # Number of workers per IP (still counts the dummy worker, since
        # worker_ips was collected before it was removed).
        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the worker is on a node with smaller IP address, it
                should be placed first.
            """
            ip = item.ip
            return (0 if ip == driver_ip else 1, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        sorted_worker_metadata = sorted(worker_metadata,
                                        key=sort_by_driver_then_worker_ip)
        # In non-SPMD mode rank 0 is the in-process driver worker.
        start_rank = 0 if self.use_ray_spmd_worker else 1
        for i, item in enumerate(sorted_worker_metadata):
            item.adjusted_rank = i + start_rank
        self.workers = [item.worker for item in sorted_worker_metadata]
        rerank_mapping = {
            item.created_rank: item.adjusted_rank
            for item in sorted_worker_metadata
        }
        self._run_workers("adjust_rank", rerank_mapping)

        # Get the set of GPU IDs used on each node. All queries are issued
        # first and resolved with a single ray.get (results keep list order).
        worker_node_and_gpu_ids = ray.get([
            worker.get_node_and_gpu_ids.remote()  # type: ignore[attr-defined]
            for worker in [self.driver_dummy_worker] + self.workers
            # driver_dummy_worker can be None when using ray spmd worker.
            if worker is not None
        ])

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP`"
                " environment variable, make sure it is unique for"
                " each node.")

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [{
            current_platform.device_control_env_var:
            ",".join(map(str, node_gpus[node_id])),
        } for (node_id, _) in worker_node_and_gpu_ids]

        # Environment variables to copy from driver to workers
        env_vars_to_copy = get_env_vars_to_copy(
            exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
            additional_vars=set(current_platform.additional_env_vars).union(
                self.ADDITIONAL_ENV_VARS),
            destination="workers")

        # Copy existing env vars to each worker's args
        for args in all_args_to_update_environment_variables:
            # TODO: refactor platform-specific env vars
            for name in env_vars_to_copy:
                if name in os.environ:
                    args[name] = os.environ[name]

        self._env_vars_for_all_workers = (
            all_args_to_update_environment_variables)

        self._run_workers("update_environment_variables",
                          self._get_env_vars_to_be_updated())

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            local_rank = node_workers[node_id].index(rank)
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
            all_kwargs.append(kwargs)
        self._run_workers("init_worker", all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        if self.use_ray_spmd_worker:
            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                self.pp_tp_workers.append([])
                for tp_rank in range(
                        self.parallel_config.tensor_parallel_size):
                    # PP=2, TP=4
                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
                            ) + tp_rank
                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                    assert pp_rank < len(self.pp_tp_workers)
                    self.pp_tp_workers[pp_rank].append(self.workers[rank])

        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[RayWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: List[RayWorkerWrapper] = []

        # Enforce rank order for correct rank to return final output.
        for index, worker in enumerate(self.workers):
            # The driver worker is rank 0 and not in self.workers.
            rank = index + 1
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(worker)
            else:
                self.non_driver_workers.append(worker)

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Execute one model step.

        In non-SPMD mode this defers to the base class. In SPMD mode the
        request goes through the compiled graph, which is built lazily on
        first use. V1 requests and outputs are passed through as-is, while
        V0 ones are msgspec-encoded and decoded.
        """
        if not self.use_ray_spmd_worker:
            return super().execute_model(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        if self.use_v1:
            serialized_data = execute_model_req
        else:
            serialized_data = self.input_encoder.encode(execute_model_req)
        outputs = ray.get(self.forward_dag.execute(serialized_data))
        if self.use_v1:
            output = outputs[0]
        else:
            output = self.output_decoder.decode(outputs[0])
        return output

    def _run_workers(
        self,
        method: Union[str, Callable],
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Run ``method`` on the workers and collect the results.

        Args:
            method: Name of the worker method to call, or a callable. A
                callable is serialized with cloudpickle before being sent.
            *args, **kwargs: Forwarded unchanged to every worker (all
                workers share the same args/kwargs).
            async_run_tensor_parallel_workers_only: If True, the method is
                run only on ``self.non_driver_workers``, not the driver
                worker. The call does not block; the list of Ray futures is
                returned instead. Not supported in SPMD mode.
            max_concurrent_workers: Not supported; a truthy value raises
                NotImplementedError.

        Returns:
            In non-SPMD mode, the driver worker's result followed by the Ray
            workers' results in ``self.workers`` order. In SPMD mode, only
            the Ray workers' results. If
            ``async_run_tensor_parallel_workers_only`` is set, the futures.
        """
        if isinstance(method, str):
            sent_method = method
        else:
            sent_method = cloudpickle.dumps(method)
        del method
        if self.use_ray_spmd_worker:
            assert not async_run_tensor_parallel_workers_only, (
                "async_run_tensor_parallel_workers_only is not supported for "
                "spmd mode.")

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        # Start the ray workers first.
        ray_workers = self.workers
        if async_run_tensor_parallel_workers_only:
            ray_workers = self.non_driver_workers
        ray_worker_outputs = [
            worker.execute_method.remote(sent_method, *args, **kwargs)
            for worker in ray_workers
        ]

        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_worker_output = []
        # In SPMD mode, the driver worker is the same as any other worker,
        # so we only explicitly execute on the driver worker if using a
        # non-SPMD worker class.
        if not self.use_ray_spmd_worker:
            # Start the driver worker after all the ray workers.
            driver_worker_output = [
                self.driver_worker.execute_method(sent_method, *args, **kwargs)
            ]

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return driver_worker_output + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only=True to complete."""
        ray.get(parallel_worker_tasks)

    def _check_ray_cgraph_installation(self):
        """Verify that Ray Compiled Graph and its dependencies are usable.

        Raises:
            ValueError: if Ray is older than 2.43.0, if
                ``ray.experimental.compiled_dag_ref`` cannot be found, or if
                the channel type is "nccl" and cupy is not installed.
        """
        import importlib.metadata

        from packaging import version

        required_version = version.parse("2.43.0")
        current_version = version.parse(importlib.metadata.version("ray"))
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} is "
                             f"required, but found {current_version}")

        import importlib.util
        cgraph_spec = importlib.util.find_spec(
            "ray.experimental.compiled_dag_ref")
        if cgraph_spec is None:
            raise ValueError("Ray Compiled Graph is not installed. "
                             "Run `pip install ray[cgraph]` to install it.")

        cupy_spec = importlib.util.find_spec("cupy")
        if (cupy_spec is None
                and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"):
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
                "Run `pip install ray[cgraph]` and check cupy installation.")

    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the PP x TP Ray Compiled Graph.

        Args:
            enable_asyncio: Passed to ``experimental_compile``. True for the
                graph used by ``execute_model_async``.

        Raises:
            ValueError: if VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is not one
                of 'auto', 'nccl' or 'shm', or if the installation check
                fails.
        """
        assert self.parallel_config.use_ray
        self._check_ray_cgraph_installation()
        # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
        # (it is 10 seconds by default). This is a Ray environment variable to
        # control the timeout of getting result from a compiled graph execution,
        # i.e., the distributed execution that includes model forward runs and
        # intermediate tensor communications, in the case of vllm.
        # Note: we should set this env var before importing
        # ray.dag, otherwise it will not take effect.
        os.environ.setdefault("RAY_CGRAPH_get_timeout", "300")  # noqa: SIM112
        from ray.dag import InputNode, MultiOutputNode
        logger.info("RAY_CGRAPH_get_timeout is set to %s",
                    os.environ["RAY_CGRAPH_get_timeout"])  # noqa: SIM112
        logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
                    envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE)
        logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
                    envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)

        channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
        if channel_type not in ("auto", "nccl", "shm"):
            raise ValueError(
                "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
                f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.")

        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            #
            # For V0:
            # ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput   # noqa: E501
            #
            # For V1:
            # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput   # noqa: E501

            # All workers in the first TP group will take in the
            # ExecuteModelRequest as input.
            outputs = [input_data for _ in self.pp_tp_workers[0]]
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                if self.use_v1:
                    outputs = [
                        worker.execute_model_ray.
                        bind(  # type: ignore[attr-defined]
                            outputs[i]) for i, worker in enumerate(tp_group)
                    ]
                else:
                    outputs = [
                        worker.execute_model_spmd.
                        bind(  # type: ignore[attr-defined]
                            outputs[i]) for i, worker in enumerate(tp_group)
                    ]

                last_pp_rank = len(self.pp_tp_workers) - 1
                if (pp_rank < last_pp_rank and
                        envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm"):
                    # Specify how intermediate tensors should be passed
                    # between pp stages, no need to specify for the last
                    # pp stage or when using shared memory ("shm").
                    transport = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
                    outputs = [
                        output.with_tensor_transport(transport=transport)
                        for output in outputs
                    ]

            forward_dag = MultiOutputNode(outputs)

        if envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
            from ray.experimental.channel.accelerator_context import (
                register_accelerator_context)

            from vllm.distributed.device_communicators.ray_communicator import (
                RayPPCommunicator)
            register_accelerator_context(torch_module_name="cuda",
                                         communicator_cls=RayPPCommunicator)
            logger.info("Using RayPPCommunicator "
                        "(which wraps vLLM _PP GroupCoordinator) "
                        "for Ray Compiled Graph communication.")
        else:
            logger.info("Using Ray's NCCL communicator for "
                        "Ray Compiled Graph communication.")

        return forward_dag.experimental_compile(
            enable_asyncio=enable_asyncio,
            _overlap_gpu_communication=envs.
            VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)

    def __del__(self):
        self.shutdown()

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Async counterpart of ``execute_model``.

        Uses the same V1/V0 (de)serialization rules as ``execute_model``.
        V1 requests are sent to the compiled graph as-is and its output is
        returned as-is. V0 requests are msgspec-encoded and the output is
        decoded.
        """
        if not self.use_ray_spmd_worker:
            return await super().execute_model_async(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

        if self.use_v1:
            serialized_data = execute_model_req
        else:
            serialized_data = self.input_encoder.encode(execute_model_req)
        dag_future = await self.forward_dag.execute_async(serialized_data)
        output = await dag_future[0]
        if self.use_v1:
            return output
        return self.output_decoder.decode(output)

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model on the driver and every TP-group driver.

        Each pipeline stage is serialized by its own lock. The result of the
        last pipeline stage is returned.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
        if not self.tp_driver_workers:
            return await self.driver_exec_method("execute_model",
                                                 execute_model_req)
        if self.pp_locks is None:
            # This locks each pipeline parallel stage so multiple virtual
            # engines can't execute on the same stage at the same time
            # We create the locks here to avoid creating them in the constructor
            # which uses a different asyncio loop.
            self.pp_locks = [
                asyncio.Lock()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]

        tasks = [
            asyncio.create_task(
                _run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
                                    "execute_model", execute_model_req))
        ]
        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
                                                start=1):
            tasks.append(
                asyncio.create_task(
                    _run_task_with_lock(driver_worker.execute_method.remote,
                                        self.pp_locks[pp_rank],
                                        "execute_model", execute_model_req)))

        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]

    async def _start_worker_execution_loop(self):
        """Start ``start_worker_execution_loop`` on every non-driver worker
        and await all of them."""
        assert not self.use_ray_spmd_worker, (
            "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.non_driver_workers
        ]
        return await asyncio.gather(*coros)

    def check_health(self) -> None:
        # Assume that the Ray workers are healthy.
        # TODO: check the health of the Ray workers
        return