ray_distributed_executor.py 27.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import asyncio
import os
5
from collections import defaultdict
6
from dataclasses import dataclass
7
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
8

9
import cloudpickle
10
11
import msgspec

12
import vllm.envs as envs
13
14
from vllm.executor.executor_base import (
    DistributedExecutorBase)  # yapf: disable
15
from vllm.executor.msgspec_utils import encode_hook
16
17
from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
                                     ray)
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.sampler import SamplerOutput
20
from vllm.platforms import current_platform
21
from vllm.sequence import ExecuteModelRequest
22
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
23
                        get_ip, get_open_port, make_async)
24
25

if ray is None:
    # Ray is not installed. Keep the name bound so that annotations which
    # reference it (e.g. ``RayWorkerMetaData.worker``) still evaluate at
    # import time.
    ActorHandle = None
else:
    from ray.actor import ActorHandle
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    # Used for the string annotation of ``_init_workers_ray``.
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


@dataclass
class RayWorkerMetaData:
    """Bookkeeping record for a single Ray worker actor.

    Ray may bring worker actors up in an arbitrary order, so each record
    keeps the rank the actor was created with next to the rank it is given
    once every worker exists and the list has been re-sorted.
    """
    # Handle to the remote RayWorkerWrapper actor.
    worker: ActorHandle
    # Rank passed as ``rpc_rank`` when the actor was created.
    created_rank: int
    # Rank after re-sorting; stays -1 until it is assigned.
    adjusted_rank: int = -1
    # IP of the node hosting the actor; empty until it has been queried.
    ip: str = ""


class RayDistributedExecutor(DistributedExecutorBase):
51

52
53
    uses_ray: bool = True

    def _init_executor(self) -> None:
        """Initialise the Ray cluster, the workers and the message codecs.

        Reads the Ray-related ``VLLM_*`` flags (forcing compiled-DAG and SPMD
        mode for V1), checks that those two modes are enabled together,
        starts or joins the Ray cluster via ``initialize_ray_cluster``, and
        creates the workers on its placement group.
        """
        # Set first so ``shutdown`` (also called from ``__del__``) can rely on
        # the attribute even if initialisation fails part-way.
        self.forward_dag: Optional[ray.dag.CompiledDAG] = None
        if envs.VLLM_USE_V1:
            # v1 always uses the compiled DAG and SPMD worker.
            os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
            os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

        # If the env var is set, it uses the Ray's compiled DAG API
        # which optimizes the control plane overhead.
        # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
        # Currently, this requires USE_RAY_SPMD_WORKER=True.
        self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
        # If the env var is set, then we do not distinguish between the
        # "driver worker" vs other workers. Also, the rank 0 worker will
        # be executed in a remote Ray worker. Currently this requires
        # USE_RAY_COMPILED_DAG=True.
        self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
        if self.use_ray_compiled_dag:
            assert self.use_ray_spmd_worker, (
                "VLLM_USE_RAY_COMPILED_DAG=1 requires "
                "VLLM_USE_RAY_SPMD_WORKER=1")
        if self.use_ray_spmd_worker:
            # TODO: Support SPMD worker for non-DAG Ray executor.
            assert self.use_ray_compiled_dag, (
                "VLLM_USE_RAY_SPMD_WORKER=1 requires "
                "VLLM_USE_RAY_COMPILED_DAG=1")

        assert self.uses_ray

        # Disable Ray usage stats collection unless the user opted in.
        # This has to happen before ``initialize_ray_cluster``. Ray reads the
        # variable when it starts, so setting it afterwards is too late for
        # that ``ray.init``.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        initialize_ray_cluster(self.parallel_config)
        placement_group = self.parallel_config.placement_group

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        # V0 requests/outputs travel through the compiled DAG as msgpack.
        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
        self.output_decoder = msgspec.msgpack.Decoder(
            Optional[List[SamplerOutput]])
        self.use_v1 = envs.VLLM_USE_V1

        # Created lazily in ``_driver_execute_model_async`` so the locks bind
        # to the running event loop rather than the constructor's.
        self.pp_locks: Optional[List[asyncio.Lock]] = None
        if not self.use_ray_compiled_dag:
            self.driver_exec_method = make_async(
                self.driver_worker.execute_method)

    def shutdown(self) -> None:
        """Tear down the compiled DAG, if one was built, and kill the actors.

        Repeated calls are harmless, and so are calls on an executor whose
        initialisation never reached ``forward_dag``. Once the compiled DAG
        has been cleared, a call only logs the shutdown message.
        """
        logger.info(
            "Shutting down Ray distributed executor. If you see error log "
            "from logging.cc regarding SIGTERM received, please ignore because "
            "this is the expected termination process in Ray.")
        compiled_dag = getattr(self, "forward_dag", None)
        if compiled_dag is None:
            return
        compiled_dag.teardown()
        import ray
        for actor in self.workers:
            ray.kill(actor)
        self.forward_dag = None

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        """Add an Nsight profiling section to the Ray actor options.

        The ``runtime_env`` entry of ``ray_remote_kwargs`` is created if it is
        missing, and its ``"nsight"`` key is (re)written. The same dict is
        mutated in place and returned.
        """
        nsight_options = {
            "t": "cuda,cudnn,cublas",
            "o": "'worker_process_%p'",
            "cuda-graph-trace": "node",
        }
        env_section = ray_remote_kwargs.setdefault("runtime_env", {})
        env_section.update({"nsight": nsight_options})
        return ray_remote_kwargs

    # Subclasses may override this to supply their own per-worker
    # environment variables.
    def _get_env_vars_to_be_updated(self):
        """Return the environment-variable dicts to push to the workers.

        By default this is the list built by ``_init_workers_ray``: one dict
        per queried worker, with the driver dummy worker first when present.
        """
        env_vars_per_worker = self._env_vars_for_all_workers
        return env_vars_per_worker

    def _resolve_bundle_indices(self,
                                placement_group: "PlacementGroup") -> List[int]:
        """Return the placement-group bundle index to use for each rank.

        Uses ``VLLM_RAY_BUNDLE_INDICES`` (comma-separated ints) when it is
        set. Otherwise uses the first ``world_size`` bundles that reserve the
        current platform's Ray device resource.

        Raises:
            ValueError: if ``VLLM_RAY_BUNDLE_INDICES`` is not a list of
                exactly ``world_size`` distinct integers.
        """
        world_size = self.parallel_config.world_size
        if envs.VLLM_RAY_BUNDLE_INDICES:
            # Use the bundle indices specified by the user. This is user
            # input, so validate it with real exceptions: ``assert`` is
            # stripped under ``python -O``.
            bundle_indices = list(
                map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
            if len(bundle_indices) != world_size:
                raise ValueError(
                    "VLLM_RAY_BUNDLE_INDICES must have the same size"
                    f" as the world size, but got {bundle_indices=} "
                    f"and {self.parallel_config.world_size=}")
            if len(set(bundle_indices)) != len(bundle_indices):
                raise ValueError(
                    "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
                    f" but got {bundle_indices=}")
            return bundle_indices
        # use the first N bundles that have GPU resources.
        return [
            bundle_id
            for bundle_id, bundle in enumerate(placement_group.bundle_specs)
            if bundle.get(current_platform.ray_device_key, 0)
        ][:world_size]

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Create, rank and initialise the Ray worker actors.

        One ``RayWorkerWrapper`` actor is created per bundle returned by
        ``_resolve_bundle_indices``. In non-SPMD mode, the actor on the
        driver's node becomes ``driver_dummy_worker``, a resource holder, and
        a local ``driver_worker`` is created with rank 0. The remaining
        actors are then processed in order:

        * sorted (driver node first, then less-populated nodes, then by IP)
          and re-ranked;
        * given per-node device-visibility environment variables;
        * made to run ``init_worker``, ``init_device`` and ``load_model``.

        Args:
            placement_group: Placement group whose bundles host the workers.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote`` for
                every worker actor.

        Raises:
            ValueError: if ``VLLM_RAY_BUNDLE_INDICES`` is malformed, or if no
                worker landed on the driver node in non-SPMD mode.
            RuntimeError: if node ids and IP addresses are not one-to-one.
        """
        num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: List[List[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)

        # Create the workers.
        bundle_indices = self._resolve_bundle_indices(placement_group)

        worker_metadata: List[RayWorkerMetaData] = []
        driver_ip = get_ip()
        for rank, bundle_id in enumerate(bundle_indices):
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            if current_platform.ray_device_key == "GPU":
                # NV+AMD GPUs, and Intel XPUs
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=num_gpus,
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
                                           rpc_rank=rank)
            else:
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=0,
                    resources={current_platform.ray_device_key: num_gpus},
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
                                           rpc_rank=rank)
            worker_metadata.append(
                RayWorkerMetaData(worker=worker, created_rank=rank))

        worker_ips = ray.get([
            each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
            for each in worker_metadata
        ])

        for each, ip in zip(worker_metadata, worker_ips):
            each.ip = ip

        if not self.use_ray_spmd_worker:
            for i, each in enumerate(worker_metadata):
                # find and remove the dummy worker from the list
                if self.driver_dummy_worker is None and each.ip == driver_ip:
                    # If the worker is on the same node as the driver, we use it
                    # as the resource holder for the driver process.
                    self.driver_dummy_worker = each.worker
                    self.driver_worker = RayWorkerWrapper(
                        vllm_config=self.vllm_config, rpc_rank=0)
                    # Mutating the list is safe: we stop iterating right away.
                    worker_metadata.pop(i)
                    break

        logger.debug("workers: %s", worker_metadata)
        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the work is on a node with smaller IP address, it
                should be placed first.
            """
            ip = item.ip
            return (0 if ip == driver_ip else 1, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        sorted_worker_metadata = sorted(worker_metadata,
                                        key=sort_by_driver_then_worker_ip)
        start_rank = 0 if self.use_ray_spmd_worker else 1
        for i, item in enumerate(sorted_worker_metadata):
            item.adjusted_rank = i + start_rank
        self.workers = [item.worker for item in sorted_worker_metadata]
        rerank_mapping = {
            item.created_rank: item.adjusted_rank
            for item in sorted_worker_metadata
        }
        self._run_workers("adjust_rank", rerank_mapping)

        # Get the set of GPU IDs used on each node. Submit every query before
        # blocking once, instead of one serial round trip per worker.
        # driver_dummy_worker can be None when using ray spmd worker.
        workers_to_query = [
            worker for worker in [self.driver_dummy_worker] + self.workers
            if worker is not None
        ]
        worker_node_and_gpu_ids = ray.get([
            worker.get_node_and_gpu_ids.remote()  # type: ignore[attr-defined]
            for worker in workers_to_query
        ])

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP`"
                " environment variable, make sure it is unique for"
                " each node.")

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [{
            current_platform.device_control_env_var:
            ",".join(map(str, node_gpus[node_id])),
        } for (node_id, _) in worker_node_and_gpu_ids]

        for args in all_args_to_update_environment_variables:
            # some carry-over env vars from the driver
            # TODO: refactor platform-specific env vars
            for name in [
                    "VLLM_ATTENTION_BACKEND",
                    "TPU_CHIPS_PER_HOST_BOUNDS",
                    "TPU_HOST_BOUNDS",
                    "VLLM_USE_V1",
                    "VLLM_TRACE_FUNCTION",
            ]:
                if name in os.environ:
                    args[name] = os.environ[name]

        self._env_vars_for_all_workers = (
            all_args_to_update_environment_variables)

        self._run_workers("update_environment_variables",
                          self._get_env_vars_to_be_updated())

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            local_rank = node_workers[node_id].index(rank)
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
            all_kwargs.append(kwargs)
        self._run_workers("init_worker", all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        if self.use_ray_spmd_worker:
            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                self.pp_tp_workers.append([])
                for tp_rank in range(
                        self.parallel_config.tensor_parallel_size):
                    # PP=2, TP=4
                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
                            ) + tp_rank
                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                    assert pp_rank < len(self.pp_tp_workers)
                    self.pp_tp_workers[pp_rank].append(self.workers[rank])

        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[RayWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: List[RayWorkerWrapper] = []

        # Enforce rank order for correct rank to return final output.
        for index, worker in enumerate(self.workers):
            # The driver worker is rank 0 and not in self.workers.
            rank = index + 1
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(worker)
            else:
                self.non_driver_workers.append(worker)

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Execute the model on the in-process driver worker.

        Only valid in non-SPMD mode, where a local ``driver_worker`` exists.
        A ``None`` request tells the driver to stop the model execution loop
        running in each of the remote workers.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
        driver = self.driver_worker
        return driver.execute_method("execute_model", execute_model_req)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Run one model step synchronously.

        Non-SPMD mode defers to the base-class implementation. In SPMD mode
        the behaviour is:

        * the non-async compiled Ray DAG is built on first use;
        * the DAG receives the request msgpack-encoded (V0) or as-is (V1);
        * the first DAG output is returned, decoded for V0 and raw for V1.
        """
        if self.use_ray_spmd_worker:
            if self.forward_dag is None:
                self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

            dag_input = (execute_model_req if self.use_v1 else
                         self.input_encoder.encode(execute_model_req))
            first_output = ray.get(self.forward_dag.execute(dag_input))[0]
            if self.use_v1:
                return first_output
            return self.output_decoder.decode(first_output)

        return super().execute_model(execute_model_req)

    def _run_workers(
        self,
        method: Union[str, Callable],
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Invoke ``method`` on the workers and collect the results.

        ``method`` is either the name of a worker method or a callable. A
        callable is cloudpickled before being sent. Every worker receives the
        same ``args``/``kwargs``.

        Args:
            async_run_tensor_parallel_workers_only: If True, run the method
                only on ``self.non_driver_workers``, never on the local driver
                worker. The list of Ray object refs is returned immediately
                instead of waiting for results. Not supported in SPMD mode.
            max_concurrent_workers: Not supported yet. Any truthy value
                raises ``NotImplementedError``.

        Returns:
            In the blocking case, a list of results. In non-SPMD mode the
            local driver worker's result comes first. The remote workers'
            results follow in ``self.workers`` order.
        """
        payload = (method if isinstance(method, str) else
                   cloudpickle.dumps(method))
        # Drop the reference so a pickled callable is not kept alive here.
        del method

        if self.use_ray_spmd_worker:
            assert not async_run_tensor_parallel_workers_only, (
                "async_run_tensor_parallel_workers_only is not supported for "
                "spmd mode.")

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        # Dispatch to the Ray actors first; ``.remote`` does not block, so
        # they run while the driver does its own share below.
        targets = (self.non_driver_workers
                   if async_run_tensor_parallel_workers_only else self.workers)
        remote_results = [
            actor.execute_method.remote(payload, *args, **kwargs)
            for actor in targets
        ]
        if async_run_tensor_parallel_workers_only:
            return remote_results

        # In SPMD mode rank 0 is just another Ray actor, so a local driver
        # worker only exists (and only runs) in non-SPMD mode.
        local_results = [] if self.use_ray_spmd_worker else [
            self.driver_worker.execute_method(payload, *args, **kwargs)
        ]

        if self.workers:
            remote_results = ray.get(remote_results)
        return local_results + remote_results

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Block until the given Ray futures have completed.

        Intended for the object refs returned by ``_run_workers`` when it is
        called with ``async_run_tensor_parallel_workers_only=True``.
        """
        pending_refs = parallel_worker_tasks
        ray.get(pending_refs)

    def _check_ray_adag_installation(self):
        """Check that the Ray features needed by the compiled DAG are present.

        Raises:
            ValueError: if any of the following holds:
                * the installed Ray is older than 2.40;
                * the Ray accelerated DAG module is missing;
                * cupy is missing while
                  ``VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL`` is set.
        """
        import importlib.metadata
        import importlib.util

        from packaging import version

        required_version = version.parse("2.40")
        # ``importlib.metadata`` is the stdlib replacement for the deprecated
        # ``pkg_resources`` distribution lookup.
        current_version = version.parse(importlib.metadata.version("ray"))
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} is "
                             f"required, but found {current_version}")

        adag_spec = importlib.util.find_spec(
            "ray.experimental.compiled_dag_ref")
        if adag_spec is None:
            raise ValueError("Ray accelerated DAG is not installed. "
                             "Run `pip install ray[adag]` to install it.")

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. "
                "Run `pip install ray[adag]` and check cupy installation.")

    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the Ray DAG that executes one model step.

        The DAG input fans out to every worker of the first pipeline stage.
        Each stage's TP workers feed their outputs, rank for rank, into the
        next stage. Outputs crossing a pipeline boundary get a
        ``TorchTensorType`` hint: NCCL transport when
        ``VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL`` is set, otherwise
        ``"auto"``. The last stage's outputs form the DAG result.

        Args:
            enable_asyncio: Forwarded to ``experimental_compile``. The
                ``execute_model_async`` path passes True.
        """
        assert self.parallel_config.use_ray
        self._check_ray_adag_installation()
        from ray.dag import InputNode, MultiOutputNode
        from ray.experimental.channel.torch_tensor_type import TorchTensorType

        logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s",
                    envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL)
        logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
                    envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)

        # V1 workers are driven through ``execute_model_ray``, V0 SPMD
        # workers through ``execute_model_spmd``.
        step_method_name = ("execute_model_ray"
                            if self.use_v1 else "execute_model_spmd")
        last_pp_rank = len(self.pp_tp_workers) - 1

        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            #
            # For V0:
            # ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput   # noqa: E501
            #
            # For V1:
            # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput   # noqa: E501

            # All workers in the first TP group will take in the
            # ExecuteModelRequest as input.
            stage_outputs = [input_data for _ in self.pp_tp_workers[0]]
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                stage_outputs = [
                    getattr(worker, step_method_name).bind(
                        stage_outputs[tp_rank])
                    for tp_rank, worker in enumerate(tp_group)
                ]
                if pp_rank == last_pp_rank:
                    # The final stage's outputs need no transport hint.
                    continue
                # Specify how intermediate tensors should be passed
                # between pp stages.
                transport = ("nccl"
                             if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL
                             else "auto")
                stage_outputs = [
                    out.with_type_hint(TorchTensorType(transport=transport))
                    for out in stage_outputs
                ]

            forward_dag = MultiOutputNode(stage_outputs)

        return forward_dag.experimental_compile(
            enable_asyncio=enable_asyncio,
            _overlap_gpu_communication=envs.
            VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)

    def __del__(self):
        """Release Ray resources when the executor is garbage-collected.

        Delegates to ``shutdown``, which checks for ``forward_dag`` before
        using it and so copes with a partially-initialised executor.
        """
        self.shutdown()

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Asynchronous counterpart of ``execute_model``.

        Non-SPMD mode defers to the base-class implementation. In SPMD mode
        the behaviour is:

        * the asyncio-enabled compiled Ray DAG is built on first use;
        * the request is sent as-is for V1 and msgpack-encoded for V0;
        * the first DAG output is awaited and returned, decoded for V0.
        """
        if not self.use_ray_spmd_worker:
            return await super().execute_model_async(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

        # Mirror ``execute_model``: the V1 DAG binds ``execute_model_ray``,
        # which takes and returns plain objects. Only V0 uses msgpack, in both
        # directions.
        if self.use_v1:
            serialized_data = execute_model_req
        else:
            serialized_data = self.input_encoder.encode(execute_model_req)
        dag_future = await self.forward_dag.execute_async(serialized_data)
        output = await dag_future[0]
        if self.use_v1:
            return output
        return self.output_decoder.decode(output)

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run one step across all pipeline stages from the driver (non-SPMD).

        With no remote TP driver workers, the local driver runs the step
        directly. Otherwise the local driver (stage 0) and the remote TP
        driver worker of each later stage are launched as concurrent tasks,
        each under that stage's lock. The last stage's result is returned.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
        if not self.tp_driver_workers:
            return await self.driver_exec_method("execute_model",
                                                 execute_model_req)

        if self.pp_locks is None:
            # This locks each pipeline parallel stage so multiple virtual
            # engines can't execute on the same stage at the same time
            # We create the locks here to avoid creating them in the constructor
            # which uses a different asyncio loop.
            self.pp_locks = [
                asyncio.Lock()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]

        # Stage 0 runs on the local driver; stage k >= 1 runs on the k-th
        # remote TP driver worker.
        stage_runners = [self.driver_exec_method] + [
            tp_driver.execute_method.remote
            for tp_driver in self.tp_driver_workers
        ]
        tasks = [
            asyncio.create_task(
                _run_task_with_lock(runner, self.pp_locks[stage],
                                    "execute_model", execute_model_req))
            for stage, runner in enumerate(stage_runners)
        ]
        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]

    async def _start_worker_execution_loop(self):
        """Start the execution loop on every non-driver worker and await them.

        Not available in SPMD mode.
        """
        assert not self.use_ray_spmd_worker, (
            "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
        loop_refs = []
        for worker in self.non_driver_workers:
            loop_refs.append(
                worker.execute_method.remote("start_worker_execution_loop"))
        return await asyncio.gather(*loop_refs)

    def check_health(self) -> None:
        """Report executor health; currently always succeeds.

        TODO: actually probe the Ray workers instead of assuming they are
        healthy.
        """
        return None