# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import cloudpickle
import msgspec

import vllm.envs as envs
from vllm.executor.executor_base import (
    DistributedExecutorBase)  # yapf: disable
from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
                                     ray)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
                        get_ip, get_open_port, make_async)

if ray is not None:
    from ray.actor import ActorHandle
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
    ActorHandle = None

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


@dataclass
class RayWorkerMetaData:
    """Bookkeeping record for a single Ray worker actor.

    Ray may bring worker actors up in an arbitrary order, so each record
    keeps the rank the worker was created with. Once every worker exists,
    the executor computes and stores the final ``adjusted_rank``.
    """
    # Handle of the Ray actor wrapping the worker.
    worker: ActorHandle
    # Rank the worker was created with (its creation index).
    created_rank: int
    # Final rank after re-ranking; -1 means "not assigned yet".
    adjusted_rank: int = -1
    # IP address of the node hosting the actor; empty until it is queried.
    ip: str = ""


class RayDistributedExecutor(DistributedExecutorBase):
    """Ray-based distributed executor.

    Every rank of the model runs inside a ``RayWorkerWrapper``. Two modes
    are supported, selected by environment variables:

    * Non-SPMD (default): rank 0 runs in-process as ``self.driver_worker``.
      A Ray actor on the driver node (``self.driver_dummy_worker``) only
      reserves that rank's resources. All other ranks are Ray actors in
      ``self.workers``.
    * SPMD (``VLLM_USE_RAY_SPMD_WORKER=1``, which requires
      ``VLLM_USE_RAY_COMPILED_DAG=1``): every rank, including rank 0, is a
      Ray actor in ``self.workers``. Model execution goes through a Ray
      Compiled Graph (``self.forward_dag``).
    """

    # These env vars are worker-specific, therefore are NOT copied
    # from the driver to the workers
    WORKER_SPECIFIC_ENV_VARS = {
        "VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES"
    }

    uses_ray: bool = True

    def _init_executor(self) -> None:
        """Select the execution mode, start the Ray workers and set up the
        (de)serializers used for compiled-graph execution."""
        self.forward_dag: Optional[ray.dag.CompiledDAG] = None
        if envs.VLLM_USE_V1:
            # V1 uses SPMD worker and compiled DAG
            os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
            os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

            # For TPU, avoid compiling NVIDIA's NCCL
            if current_platform.is_tpu():
                os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"

        # If the env var is set, it uses the Ray's compiled DAG API
        # which optimizes the control plane overhead.
        # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
        # Currently, this requires USE_RAY_SPMD_WORKER=True.
        self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
        # If the env var is set, then we do not distinguish between the
        # "driver worker" vs other workers. Also, the rank 0 worker will
        # be executed in a remote Ray worker. Currently this requires
        # USE_RAY_COMPILED_DAG=True.
        self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
        if self.use_ray_compiled_dag:
            assert self.use_ray_spmd_worker, (
                "VLLM_USE_RAY_COMPILED_DAG=1 requires "
                "VLLM_USE_RAY_SPMD_WORKER=1")
        if self.use_ray_spmd_worker:
            # TODO: Support SPMD worker for non-DAG Ray executor.
            assert self.use_ray_compiled_dag, (
                "VLLM_USE_RAY_SPMD_WORKER=1 requires "
                "VLLM_USE_RAY_COMPILED_DAG=1")

        assert self.uses_ray
        initialize_ray_cluster(self.parallel_config)
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        # Used for the V0 compiled-graph path, where requests and outputs
        # cross the graph as msgpack bytes.
        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
        self.output_decoder = msgspec.msgpack.Decoder(
            Optional[List[SamplerOutput]])
        self.use_v1 = envs.VLLM_USE_V1

        # Created lazily inside the running event loop; see
        # _driver_execute_model_async.
        self.pp_locks: Optional[List[asyncio.Lock]] = None
        if not self.use_ray_compiled_dag:
            self.driver_exec_method = make_async(
                self.driver_worker.execute_method)

    def shutdown(self) -> None:
        """Tear down the compiled graph (if one was built) and kill the
        worker actors it used.

        The ``hasattr`` guard makes this safe on a partially-constructed
        executor (it is also called from ``__del__``). Once the graph has
        been torn down, ``forward_dag`` is reset to None, so later calls
        only log.
        """
        logger.info(
            "Shutting down Ray distributed executor. If you see error log "
            "from logging.cc regarding SIGTERM received, please ignore because "
            "this is the expected termination process in Ray.")
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            import ray
            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        """Add an nsight profiling config to the ``runtime_env`` entry of
        ``ray_remote_kwargs``.

        The dict is mutated in place and also returned.
        """
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    # child class could overwrite this to return actual env vars.
    def _get_env_vars_to_be_updated(self):
        """Return the per-rank env-var dicts sent to every worker via
        ``update_environment_variables`` (list index == rank)."""
        return self._env_vars_for_all_workers

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Create the Ray worker actors and initialize them.

        The steps are:

        1. Pick the placement-group bundles.
        2. Create one actor per bundle.
        3. In non-SPMD mode, choose a driver-node actor as the resource
           holder for the in-process driver.
        4. Re-rank the workers so that the driver node comes first and
           workers on the same node are adjacent.
        5. Propagate environment variables.
        6. Run ``init_worker``, ``init_device`` and ``load_model`` on every
           worker.

        Args:
            placement_group: Ray placement group whose bundles host the
                workers.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote``
                for every worker actor.

        Raises:
            ValueError: In non-SPMD mode, if no worker landed on the driver
                node.
            RuntimeError: If the number of distinct node ids differs from
                the number of distinct IP addresses.
        """
        num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: List[List[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)

        # Create the workers.
        bundle_indices: List[int]
        if envs.VLLM_RAY_BUNDLE_INDICES:
            # Use the bundle indices specified by the user.
            bundle_indices = list(
                map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
            assert len(bundle_indices) == self.parallel_config.world_size, (
                "VLLM_RAY_BUNDLE_INDICES must have the same size"
                f" as the world size, but got {bundle_indices=} "
                f"and {self.parallel_config.world_size=}")
            assert len(set(bundle_indices)) == len(bundle_indices), (
                "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
                f" but got {bundle_indices=}")
        else:
            # use the first N bundles that have GPU resources.
            bundle_indices = []
            for bundle_id, bundle in enumerate(placement_group.bundle_specs):
                if bundle.get(current_platform.ray_device_key, 0):
                    bundle_indices.append(bundle_id)
            bundle_indices = bundle_indices[:self.parallel_config.world_size]

        worker_metadata: List[RayWorkerMetaData] = []
        driver_ip = get_ip()
        for rank, bundle_id in enumerate(bundle_indices):
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            if current_platform.ray_device_key == "GPU":
                # NV+AMD GPUs, and Intel XPUs
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=num_gpus,
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
                                           rpc_rank=rank)
            else:
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=0,
                    resources={current_platform.ray_device_key: num_gpus},
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
                                           rpc_rank=rank)
            worker_metadata.append(
                RayWorkerMetaData(worker=worker, created_rank=rank))

        worker_ips = ray.get([
            each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
            for each in worker_metadata
        ])

        for each, ip in zip(worker_metadata, worker_ips):
            each.ip = ip

        # Created rank of the actor that holds the driver's resources
        # (non-SPMD mode only).
        driver_created_rank: Optional[int] = None
        if not self.use_ray_spmd_worker:
            for i, each in enumerate(worker_metadata):
                # find and remove the dummy worker from the list
                if each.ip == driver_ip:
                    # If the worker is on the same node as the driver, we use
                    # it as the resource holder for the driver process.
                    self.driver_dummy_worker = each.worker
                    driver_created_rank = each.created_rank
                    # Give the in-process driver the dummy's created rank
                    # rather than a fixed 0. If the dummy was not created at
                    # rank 0, another remote worker owns created rank 0, and
                    # a driver at rpc_rank 0 would be re-ranked together with
                    # it below. The explicit mapping entry added later moves
                    # the driver to rank 0.
                    self.driver_worker = RayWorkerWrapper(
                        vllm_config=self.vllm_config,
                        rpc_rank=driver_created_rank)
                    worker_metadata.pop(i)
                    break

        logger.debug("workers: %s", worker_metadata)
        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. "
                f"Driver IP: {driver_ip}, worker IPs: {worker_ips}. "
                "Consider adjusting the Ray placement group or running "
                "the driver on a GPU node.")

        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the worker is on a node with smaller IP address, it
                should be placed first.
            """
            ip = item.ip
            return (0 if ip == driver_ip else 1, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        sorted_worker_metadata = sorted(worker_metadata,
                                        key=sort_by_driver_then_worker_ip)
        # In non-SPMD mode rank 0 is reserved for the in-process driver.
        start_rank = 0 if self.use_ray_spmd_worker else 1
        for i, item in enumerate(sorted_worker_metadata):
            item.adjusted_rank = i + start_rank
        self.workers = [item.worker for item in sorted_worker_metadata]
        rerank_mapping = {
            item.created_rank: item.adjusted_rank
            for item in sorted_worker_metadata
        }
        if driver_created_rank is not None:
            # The dummy was removed from worker_metadata, so its created rank
            # is not a key yet. Map it to 0 so the in-process driver becomes
            # rank 0.
            rerank_mapping[driver_created_rank] = 0
        self._run_workers("adjust_rank", rerank_mapping)

        # Get the (node id, GPU ids) pair of every rank. The list index is
        # the rank: the driver dummy (if any) comes first, followed by
        # self.workers in adjusted-rank order. driver_dummy_worker is None
        # when using ray spmd worker. Submit all queries before waiting.
        worker_node_and_gpu_ids = ray.get([
            worker.get_node_and_gpu_ids.remote()  # type: ignore[attr-defined]
            for worker in [self.driver_dummy_worker] + self.workers
            if worker is not None
        ])

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP`"
                " environment variable, make sure it is unique for"
                " each node.")

        # Set environment variables for the driver and workers.
        # Each rank sees every device used by any rank on its node.
        all_args_to_update_environment_variables = [{
            current_platform.device_control_env_var:
            ",".join(map(str, node_gpus[node_id])),
        } for (node_id, _) in worker_node_and_gpu_ids]

        # Environment variables to copy from driver to workers
        env_vars_to_copy = get_env_vars_to_copy(
            exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
            additional_vars=set(current_platform.additional_env_vars),
            destination="workers")

        # Copy existing env vars to each worker's args
        for args in all_args_to_update_environment_variables:
            # TODO: refactor platform-specific env vars
            for name in env_vars_to_copy:
                if name in os.environ:
                    args[name] = os.environ[name]

        self._env_vars_for_all_workers = (
            all_args_to_update_environment_variables)

        self._run_workers("update_environment_variables",
                          self._get_env_vars_to_be_updated())

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            local_rank = node_workers[node_id].index(rank)
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
            all_kwargs.append(kwargs)
        self._run_workers("init_worker", all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        if self.use_ray_spmd_worker:
            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                self.pp_tp_workers.append([])
                for tp_rank in range(
                        self.parallel_config.tensor_parallel_size):
                    # PP=2, TP=4
                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
                            ) + tp_rank
                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                    assert pp_rank < len(self.pp_tp_workers)
                    self.pp_tp_workers[pp_rank].append(self.workers[rank])

        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[RayWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: List[RayWorkerWrapper] = []

        # Enforce rank order for correct rank to return final output.
        # NOTE: the `index + 1` assumes the non-SPMD layout; these lists are
        # only consumed by non-SPMD code paths.
        for index, worker in enumerate(self.workers):
            # The driver worker is rank 0 and not in self.workers.
            rank = index + 1
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(worker)
            else:
                self.non_driver_workers.append(worker)

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Run one model step.

        Non-SPMD mode delegates to the base class. SPMD mode compiles the
        Ray graph on first use, then pushes the request through it. V1 passes
        the request object as-is; V0 uses msgpack-encoded payloads.
        """
        if not self.use_ray_spmd_worker:
            return super().execute_model(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        if self.use_v1:
            serialized_data = execute_model_req
        else:
            serialized_data = self.input_encoder.encode(execute_model_req)
        outputs = ray.get(self.forward_dag.execute(serialized_data))
        if self.use_v1:
            output = outputs[0]
        else:
            output = self.output_decoder.decode(outputs[0])
        return output

    def _run_workers(
        self,
        method: Union[str, Callable],
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Run ``method`` on the workers and collect the results.

        Args:
            method: Name of the worker method to call, or a callable that is
                cloudpickled and shipped to the workers.
            async_run_tensor_parallel_workers_only: If True the method will
                be run only in the remote TP workers
                (``self.non_driver_workers``), not the driver worker. It will
                also be run asynchronously and return a list of futures
                rather than blocking on the results. Not supported in SPMD
                mode.
            max_concurrent_workers: Not supported yet; any truthy value
                raises ``NotImplementedError``.
            *args, **kwargs: All workers share the same args/kwargs.

        Returns:
            In non-SPMD mode, the driver worker's result followed by the
            results of ``self.workers`` in order. In SPMD mode, only the
            results of ``self.workers``.
        """
        if isinstance(method, str):
            sent_method = method
        else:
            sent_method = cloudpickle.dumps(method)
        del method

        if self.use_ray_spmd_worker:
            assert not async_run_tensor_parallel_workers_only, (
                "async_run_tensor_parallel_workers_only is not supported for "
                "spmd mode.")

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        # Start the ray workers first.
        ray_workers = self.workers
        if async_run_tensor_parallel_workers_only:
            ray_workers = self.non_driver_workers
        ray_worker_outputs = [
            worker.execute_method.remote(sent_method, *args, **kwargs)
            for worker in ray_workers
        ]

        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_worker_output = []
        # In SPMD mode, the driver worker is the same as any other worker,
        # so we only explicitly execute on the driver worker if using a
        # non-SPMD worker class.
        if not self.use_ray_spmd_worker:
            # Start the driver worker after all the ray workers.
            driver_worker_output = [
                self.driver_worker.execute_method(sent_method, *args, **kwargs)
            ]

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return driver_worker_output + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only=True to complete."""
        ray.get(parallel_worker_tasks)

    def _check_ray_cgraph_installation(self):
        """Verify that the installed Ray supports Compiled Graphs.

        Raises:
            ValueError: If Ray is older than 2.43.0, if
                ``ray.experimental.compiled_dag_ref`` cannot be found, or if
                cupy is missing while the nccl channel type is selected.
        """
        import importlib.metadata

        from packaging import version

        required_version = version.parse("2.43.0")
        current_version = version.parse(importlib.metadata.version("ray"))
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} is "
                             f"required, but found {current_version}")

        import importlib.util
        cgraph_spec = importlib.util.find_spec(
            "ray.experimental.compiled_dag_ref")
        if cgraph_spec is None:
            raise ValueError("Ray Compiled Graph is not installed. "
                             "Run `pip install ray[cgraph]` to install it.")

        cupy_spec = importlib.util.find_spec("cupy")
        if (cupy_spec is None
                and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"):
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
                "Run `pip install ray[cgraph]` and check cupy installation.")

    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the Ray Compiled Graph used in SPMD mode.

        The graph input feeds every worker of the first PP stage. Each
        worker's output feeds the worker with the same TP index in the next
        stage. The outputs of the last stage form the graph output.

        Args:
            enable_asyncio: Forwarded to ``experimental_compile``; True for
                the ``execute_model_async`` path.

        Raises:
            ValueError: If VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is not one
                of 'auto', 'nccl' or 'shm' (or from the installation check).
        """
        assert self.parallel_config.use_ray
        self._check_ray_cgraph_installation()
        # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
        # (it is 10 seconds by default). This is a Ray environment variable to
        # control the timeout of getting result from a compiled graph execution,
        # i.e., the distributed execution that includes model forward runs and
        # intermediate tensor communications, in the case of vllm.
        # Note: we should set this env var before importing
        # ray.dag, otherwise it will not take effect.
        os.environ.setdefault("RAY_CGRAPH_get_timeout", "300")  # noqa: SIM112
        from ray.dag import InputNode, MultiOutputNode
        logger.info("RAY_CGRAPH_get_timeout is set to %s",
                    os.environ["RAY_CGRAPH_get_timeout"])  # noqa: SIM112
        logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
                    envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE)
        logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
                    envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)

        channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
        if channel_type not in ("auto", "nccl", "shm"):
            raise ValueError(
                "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
                f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.")

        last_pp_rank = len(self.pp_tp_workers) - 1
        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            #
            # For V0:
            # ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput   # noqa: E501
            # ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput   # noqa: E501
            #
            # For V1:
            # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput   # noqa: E501

            # All workers in the first TP group will take in the
            # ExecuteModelRequest as input.
            outputs = [input_data for _ in self.pp_tp_workers[0]]
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                if self.use_v1:
                    outputs = [
                        worker.execute_model_ray.
                        bind(  # type: ignore[attr-defined]
                            outputs[i]) for i, worker in enumerate(tp_group)
                    ]
                else:
                    outputs = [
                        worker.execute_model_spmd.
                        bind(  # type: ignore[attr-defined]
                            outputs[i]) for i, worker in enumerate(tp_group)
                    ]

                if pp_rank < last_pp_rank and channel_type != "shm":
                    # Specify how intermediate tensors should be passed
                    # between pp stages. This is not needed for the last pp
                    # stage (its outputs leave the graph) or when using
                    # shared memory.
                    outputs = [
                        output.with_tensor_transport(transport=channel_type)
                        for output in outputs
                    ]

            forward_dag = MultiOutputNode(outputs)

        return forward_dag.experimental_compile(
            enable_asyncio=enable_asyncio,
            _overlap_gpu_communication=envs.
            VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)

    def __del__(self):
        self.shutdown()

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Async counterpart of :meth:`execute_model`.

        Non-SPMD mode delegates to the base class. SPMD mode uses an
        asyncio-enabled compiled graph. Like the sync path, V1 passes the
        request object through unchanged, while V0 exchanges msgpack-encoded
        payloads.
        """
        if not self.use_ray_spmd_worker:
            return await super().execute_model_async(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

        if self.use_v1:
            serialized_data = execute_model_req
        else:
            serialized_data = self.input_encoder.encode(execute_model_req)
        dag_future = await self.forward_dag.execute_async(serialized_data)
        output = await dag_future[0]
        if self.use_v1:
            return output
        return self.output_decoder.decode(output)

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model on the driver and on each TP-group driver.

        Each pipeline stage is guarded by its own lock. Only the result of
        the last pipeline stage is returned.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
        if not self.tp_driver_workers:
            return await self.driver_exec_method("execute_model",
                                                 execute_model_req)
        if self.pp_locks is None:
            # This locks each pipeline parallel stage so multiple virtual
            # engines can't execute on the same stage at the same time
            # We create the locks here to avoid creating them in the constructor
            # which uses a different asyncio loop.
            self.pp_locks = [
                asyncio.Lock()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]

        # Stage 0 is the in-process driver; stage k (k >= 1) is
        # self.tp_driver_workers[k - 1].
        tasks = [
            asyncio.create_task(
                _run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
                                    "execute_model", execute_model_req))
        ]
        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
                                                start=1):
            tasks.append(
                asyncio.create_task(
                    _run_task_with_lock(driver_worker.execute_method.remote,
                                        self.pp_locks[pp_rank],
                                        "execute_model", execute_model_req)))

        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]

    async def _start_worker_execution_loop(self):
        """Start the execution loop on every non-driver worker and wait for
        all of them to return (non-SPMD mode only)."""
        assert not self.use_ray_spmd_worker, (
            "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.non_driver_workers
        ]
        return await asyncio.gather(*coros)

    def check_health(self) -> None:
        # Assume that the Ray workers are healthy.
        # TODO: check the health of the Ray workers
        return