import asyncio
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import cloudpickle
import msgspec

import vllm.envs as envs
from vllm.executor.executor_base import (
    DistributedExecutorBase)  # yapf: disable
from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
                                     ray)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
                        get_ip, get_open_port, make_async)

# `ray` is exported by vllm.executor.ray_utils and may be None; the Ray-only
# names are imported only when it is available.
if ray is not None:
    from ray.actor import ActorHandle
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
    ActorHandle = None

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


@dataclass
class RayWorkerMetaData:
    """Bookkeeping record for a single Ray worker actor.

    Ray does not guarantee the order in which worker actors come up, so each
    record remembers the rank the worker was created with next to the rank
    it is reassigned to once every worker exists.
    """
    # Handle to the remote worker actor.
    worker: ActorHandle
    # Rank handed to the worker when it was created.
    created_rank: int
    # Rank assigned after all workers exist; -1 means "not assigned yet".
    adjusted_rank: int = -1
    # IP address of the node hosting the worker; empty until it is filled in.
    ip: str = ""


class RayDistributedExecutor(DistributedExecutorBase):
    """Distributed executor that runs vLLM workers as Ray actors.

    The execution mode is selected from environment variables in
    ``_init_executor``:

    * Driver mode (default): rank 0 runs in-process as ``driver_worker``.
      A remote "dummy" actor on the driver node only holds the driver's
      device resources. The remaining ranks run as Ray actors in
      ``self.workers``.
    * SPMD mode (``VLLM_USE_RAY_SPMD_WORKER=1``, which requires
      ``VLLM_USE_RAY_COMPILED_DAG=1``): every rank, including rank 0, is a
      Ray actor, and model execution goes through a compiled Ray DAG.
    """

    uses_ray: bool = True

    def _init_executor(self) -> None:
        """Set up the Ray cluster, create the workers and request codecs."""
        self.forward_dag: Optional[ray.dag.CompiledDAG] = None
        if envs.VLLM_USE_V1:
            # v1 always uses the compiled DAG and SPMD worker.
            os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
            os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

        # If the env var is set, it uses the Ray's compiled DAG API
        # which optimizes the control plane overhead.
        # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
        # Currently, this requires USE_RAY_SPMD_WORKER=True.
        self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
        # If the env var is set, then we do not distinguish between the
        # "driver worker" vs other workers. Also, the rank 0 worker will
        # be executed in a remote Ray worker. Currently this requires
        # USE_RAY_COMPILED_DAG=True.
        self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
        if self.use_ray_compiled_dag:
            assert self.use_ray_spmd_worker, (
                "VLLM_USE_RAY_COMPILED_DAG=1 requires "
                "VLLM_USE_RAY_SPMD_WORKER=1")
        if self.use_ray_spmd_worker:
            # TODO: Support SPMD worker for non-DAG Ray executor.
            assert self.use_ray_compiled_dag, (
                "VLLM_USE_RAY_SPMD_WORKER=1 requires "
                "VLLM_USE_RAY_COMPILED_DAG=1")

        assert self.uses_ray
        initialize_ray_cluster(self.parallel_config)
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
        self.output_decoder = msgspec.msgpack.Decoder(
            Optional[List[SamplerOutput]])
        self.use_v1 = envs.VLLM_USE_V1

        self.pp_locks: Optional[List[asyncio.Lock]] = None
        # `use_ray_spmd_worker` was already read above; re-reading the env
        # var here was redundant.
        if not self.use_ray_compiled_dag:
            self.driver_exec_method = make_async(
                self.driver_worker.execute_method)

    def shutdown(self) -> None:
        """Tear down the compiled DAG, if any, and kill the worker actors.

        This is a no-op unless a compiled DAG was created. It is safe to call
        on a partially initialized executor, because ``forward_dag`` may not
        exist yet.
        """
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            # Local import: the module-level `ray` may be None.
            import ray
            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        """Add an nsight profiling runtime env to the Ray actor options.

        Mutates and returns ``ray_remote_kwargs``.
        """
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    # child class could overwrite this to return actual env vars.
    def _get_env_vars_to_be_updated(self):
        """Return the per-worker env-var dicts sent to the workers."""
        return self._env_vars_for_all_workers

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Create, rank, and initialize the Ray worker actors.

        Args:
            placement_group: One worker is created for every bundle that has
                the platform's Ray device resource. Each worker is pinned to
                its bundle.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote``
                for every worker actor.

        Raises:
            ValueError: In driver (non-SPMD) mode, if no worker was placed on
                the driver node.
            RuntimeError: If the number of distinct node ids differs from the
                number of distinct IP addresses.
        """
        if (self.parallel_config.tensor_parallel_size == 1
                and self.parallel_config.pipeline_parallel_size == 1):
            # For single GPU case, we use a ray worker with constrained memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: List[List[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)

        # Create the workers.
        driver_ip = get_ip()
        rank = 0
        worker_metadata: List[RayWorkerMetaData] = []
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get(current_platform.ray_device_key, 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            if current_platform.ray_device_key == "GPU":
                # NV+AMD GPUs, and Intel XPUs
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=num_gpus,
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
                                           rpc_rank=rank)
            else:
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=0,
                    resources={current_platform.ray_device_key: num_gpus},
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
                                           rpc_rank=rank)
            worker_metadata.append(
                RayWorkerMetaData(worker=worker, created_rank=rank))
            rank += 1

        worker_ips = ray.get([
            each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
            for each in worker_metadata
        ])

        for each, ip in zip(worker_metadata, worker_ips):
            each.ip = ip

        # Rank-mapping entry for the in-process driver worker. The driver is
        # created with the created rank of the dummy actor it replaces, and
        # is then mapped to 0. This keeps `adjust_rank` from also moving the
        # driver when a remote worker was created at rank 0 (which happens
        # whenever the dummy was not the first worker created).
        driver_rank_mapping: Dict[int, int] = {}
        if not self.use_ray_spmd_worker:
            for i, each in enumerate(worker_metadata):
                # find and remove the dummy worker from the list
                if self.driver_dummy_worker is None and each.ip == driver_ip:
                    # If the worker is on the same node as the driver, we use it
                    # as the resource holder for the driver process.
                    self.driver_dummy_worker = each.worker
                    self.driver_worker = RayWorkerWrapper(
                        vllm_config=self.vllm_config,
                        rpc_rank=each.created_rank)
                    driver_rank_mapping[each.created_rank] = 0
                    worker_metadata.pop(i)
                    break

        logger.debug("workers: %s", worker_metadata)
        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the worker is on a node with smaller IP address, it
                should be placed first.
            """
            ip = item.ip
            return (0 if ip == driver_ip else 1, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        sorted_worker_metadata = sorted(worker_metadata,
                                        key=sort_by_driver_then_worker_ip)
        # In driver mode, rank 0 belongs to the in-process driver worker.
        start_rank = 0 if self.use_ray_spmd_worker else 1
        for i, item in enumerate(sorted_worker_metadata):
            item.adjusted_rank = i + start_rank
        self.workers = [item.worker for item in sorted_worker_metadata]
        rerank_mapping = {
            item.created_rank: item.adjusted_rank
            for item in sorted_worker_metadata
        }
        rerank_mapping.update(driver_rank_mapping)
        self._run_workers("adjust_rank", rerank_mapping)

        # Get the set of GPU IDs used on each node. driver_dummy_worker can
        # be None when using ray spmd worker. All queries are issued first
        # and then fetched with a single ray.get, which returns results in
        # the order of the request list.
        id_query_workers = [
            worker for worker in [self.driver_dummy_worker] + self.workers
            if worker is not None
        ]
        worker_node_and_gpu_ids = ray.get([
            worker.get_node_and_gpu_ids.remote()  # type: ignore
            for worker in id_query_workers
        ])

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP`"
                " environment variable, make sure it is unique for"
                " each node.")

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [{
            current_platform.device_control_env_var:
            ",".join(map(str, node_gpus[node_id])),
        } for (node_id, _) in worker_node_and_gpu_ids]

        # some carry-over env vars from the driver
        # TODO: refactor platform-specific env vars
        carry_over_env_vars = [
            "VLLM_ATTENTION_BACKEND",
            "TPU_CHIPS_PER_HOST_BOUNDS",
            "TPU_HOST_BOUNDS",
            "VLLM_USE_V1",
            "VLLM_TRACE_FUNCTION",
        ]
        for args in all_args_to_update_environment_variables:
            for name in carry_over_env_vars:
                if name in os.environ:
                    args[name] = os.environ[name]

        self._env_vars_for_all_workers = (
            all_args_to_update_environment_variables)

        self._run_workers("update_environment_variables",
                          self._get_env_vars_to_be_updated())

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            local_rank = node_workers[node_id].index(rank)
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
            all_kwargs.append(kwargs)
        self._run_workers("init_worker", all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        if self.use_ray_spmd_worker:
            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                self.pp_tp_workers.append([])
                for tp_rank in range(
                        self.parallel_config.tensor_parallel_size):
                    # PP=2, TP=4
                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
                            ) + tp_rank
                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                    assert pp_rank < len(self.pp_tp_workers)
                    self.pp_tp_workers[pp_rank].append(self.workers[rank])

        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[RayWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: List[RayWorkerWrapper] = []

        # Enforce rank order for correct rank to return final output.
        for index, worker in enumerate(self.workers):
            # The driver worker is rank 0 and not in self.workers.
            rank = index + 1
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(worker)
            else:
                self.non_driver_workers.append(worker)

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Execute one model step.

        In driver mode, this delegates to the base class. In SPMD mode, it
        compiles the forward DAG lazily on first use and runs it. V1 requests
        and outputs are passed through unchanged; V0 ones are msgpack
        encoded and decoded.
        """
        if not self.use_ray_spmd_worker:
            return super().execute_model(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        if self.use_v1:
            serialized_data = execute_model_req
        else:
            serialized_data = self.input_encoder.encode(execute_model_req)
        outputs = ray.get(self.forward_dag.execute(serialized_data))
        if self.use_v1:
            output = outputs[0]
        else:
            output = self.output_decoder.decode(outputs[0])
        return output

    def _run_workers(
        self,
        method: Union[str, Callable],
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Run the given method on all workers.

        Args:
            method: Name of the worker method to call, or a callable. A
                callable is serialized with cloudpickle before it is sent.
            *args, **kwargs: Shared by all workers and passed unchanged.
            async_run_tensor_parallel_workers_only: If True, the method runs
                only on ``self.non_driver_workers``, so neither the driver
                worker nor the TP driver workers run it. The list of Ray
                futures is returned instead of blocking on the results.
                This option is not supported in SPMD mode.
            max_concurrent_workers: Not supported; any truthy value raises.

        Returns:
            When blocking, a list of results. In non-SPMD mode the first
            entry is the driver worker's result; after it come the results
            from ``self.workers`` in order.

        Raises:
            NotImplementedError: If ``max_concurrent_workers`` is truthy.
        """
        if isinstance(method, str):
            sent_method = method
        else:
            sent_method = cloudpickle.dumps(method)
        del method

        if self.use_ray_spmd_worker:
            assert not async_run_tensor_parallel_workers_only, (
                "async_run_tensor_parallel_workers_only is not supported for "
                "spmd mode.")

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        # Start the ray workers first.
        ray_workers = self.workers
        if async_run_tensor_parallel_workers_only:
            ray_workers = self.non_driver_workers
        ray_worker_outputs = [
            worker.execute_method.remote(sent_method, *args, **kwargs)
            for worker in ray_workers
        ]

        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_worker_output = []
        # In SPMD mode, the driver worker is the same as any other worker,
        # so we only explicitly execute on the driver worker if using a
        # non-SPMD worker class.
        if not self.use_ray_spmd_worker:
            # Start the driver worker after all the ray workers.
            driver_worker_output = [
                self.driver_worker.execute_method(sent_method, *args, **kwargs)
            ]

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return driver_worker_output + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only=True to complete."""
        ray.get(parallel_worker_tasks)

    def _check_ray_adag_installation(self):
        """Check that Ray's compiled-graph (aDAG) support is usable.

        Raises:
            ValueError: If the installed Ray is older than 2.40, if
                ``ray.experimental.compiled_dag_ref`` cannot be found, or if
                ``VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL`` is set and cupy is
                not installed.
        """
        import importlib.metadata
        import importlib.util

        from packaging import version

        required_version = version.parse("2.40")
        # importlib.metadata replaces the deprecated pkg_resources API and
        # does not require setuptools at runtime.
        current_version = version.parse(importlib.metadata.version("ray"))
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} is "
                             f"required, but found {current_version}")

        adag_spec = importlib.util.find_spec(
            "ray.experimental.compiled_dag_ref")
        if adag_spec is None:
            raise ValueError("Ray accelerated DAG is not installed. "
                             "Run `pip install ray[adag]` to install it.")

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. "
                "Run `pip install ray[adag]` and check cupy installation.")

    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the Ray DAG over ``self.pp_tp_workers``.

        Every worker of the first TP group receives the DAG input. Each TP
        group then executes in SPMD fashion and feeds the next PP stage. The
        outputs of the last PP stage form the DAG's result.

        Args:
            enable_asyncio: Forwarded to ``experimental_compile``.
        """
        assert self.parallel_config.use_ray
        self._check_ray_adag_installation()
        from ray.dag import InputNode, MultiOutputNode
        from ray.experimental.channel.torch_tensor_type import TorchTensorType

        logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s",
                    envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL)
        logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
                    envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)

        # V1 workers are driven through `execute_model`, V0 SPMD workers
        # through `execute_model_spmd`.
        execute_method_name = ("execute_model"
                               if self.use_v1 else "execute_model_spmd")
        # Transport for intermediate tensors between PP stages.
        transport = ("nccl" if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL
                     else "auto")
        last_pp_rank = len(self.pp_tp_workers) - 1

        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput   # noqa: E501
            #                         -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput   # noqa: E501
            #                         -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput   # noqa: E501
            #                         -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput   # noqa: E501

            # All workers in the first TP group will take in the
            # ExecuteModelRequest as input.
            outputs = [input_data for _ in self.pp_tp_workers[0]]
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                outputs = [
                    getattr(worker, execute_method_name).bind(outputs[i])
                    for i, worker in enumerate(tp_group)
                ]

                if pp_rank < last_pp_rank:
                    # Specify how intermediate tensors should be passed
                    # between pp stages, no need to specify for the last
                    # pp stage.
                    outputs = [
                        output.with_type_hint(
                            TorchTensorType(transport=transport))
                        for output in outputs
                    ]

            forward_dag = MultiOutputNode(outputs)

        return forward_dag.experimental_compile(
            enable_asyncio=enable_asyncio,
            _overlap_gpu_communication=envs.
            VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)

    def __del__(self):
        self.shutdown()

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Async counterpart of ``execute_model``.

        Request and output handling matches the sync path. In particular,
        V1 data is passed through unchanged, while V0 data is msgpack encoded
        and decoded.
        """
        if not self.use_ray_spmd_worker:
            return await super().execute_model_async(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

        if self.use_v1:
            serialized_data = execute_model_req
        else:
            serialized_data = self.input_encoder.encode(execute_model_req)
        dag_future = await self.forward_dag.execute_async(serialized_data)
        output = await dag_future[0]
        if self.use_v1:
            return output
        return self.output_decoder.decode(output)

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model on the driver and each PP stage's TP driver.

        Returns the result of the last task, i.e. the last PP stage.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
        if not self.tp_driver_workers:
            return await self.driver_exec_method("execute_model",
                                                 execute_model_req)
        if self.pp_locks is None:
            # This locks each pipeline parallel stage so multiple virtual
            # engines can't execute on the same stage at the same time
            # We create the locks here to avoid creating them in the constructor
            # which uses a different asyncio loop.
            self.pp_locks = [
                asyncio.Lock()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]

        tasks = [
            asyncio.create_task(
                _run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
                                    "execute_model", execute_model_req))
        ]
        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
                                                start=1):
            tasks.append(
                asyncio.create_task(
                    _run_task_with_lock(driver_worker.execute_method.remote,
                                        self.pp_locks[pp_rank],
                                        "execute_model", execute_model_req)))

        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]

    async def _start_worker_execution_loop(self):
        """Start the execution loop on every non-driver remote worker."""
        assert not self.use_ray_spmd_worker, (
            "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.non_driver_workers
        ]
        return await asyncio.gather(*coros)

    def check_health(self) -> None:
        # Assume that the Ray workers are healthy.
        # TODO: check the health of the Ray workers
        return