ray_distributed_executor.py 29 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import asyncio
4
import json
5
import os
6
from collections import defaultdict
7
from dataclasses import dataclass
8
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
9

10
import cloudpickle
11
12
import msgspec

13
import vllm.envs as envs
14
15
from vllm.executor.executor_base import (
    DistributedExecutorBase)  # yapf: disable
16
from vllm.executor.msgspec_utils import encode_hook
17
18
from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
                                     ray)
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.sampler import SamplerOutput
21
from vllm.platforms import current_platform
22
from vllm.sequence import ExecuteModelRequest
23
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
24
                        get_ip, get_open_port, make_async)
25
26

if ray is None:
    # Ray is optional. Keep the name defined so that the RayWorkerMetaData
    # annotation below can still be evaluated without Ray installed.
    ActorHandle = None
else:
    from ray.actor import ActorHandle
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


@dataclass
class RayWorkerMetaData:
    """
    Metadata for a Ray worker.

    The order of ray worker creation can be random, and we need to reset
    the rank after creating all workers. Each entry records the rank the
    worker actor was created with and the rank it is reassigned to once all
    workers and their node IPs are known.
    """
    # Handle to the Ray actor that runs the worker.
    worker: ActorHandle
    # Rank passed as ``rpc_rank`` when the actor was created (its position
    # in the bundle order).
    created_rank: int
    # Rank assigned after the workers are sorted; -1 until it is assigned.
    adjusted_rank: int = -1
    # IP address of the node hosting the worker; filled in after creation.
    ip: str = ""


class RayDistributedExecutor(DistributedExecutorBase):
    """Ray-based distributed executor"""

    # These env vars are worker-specific, therefore are NOT copied
    # from the driver to the workers
    WORKER_SPECIFIC_ENV_VARS = {
        "VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES"
    }

    config_home = envs.VLLM_CONFIG_ROOT
    # This file contains a JSON list of env vars that should not be copied
    # from the driver to the Ray workers. It is read once, when the class
    # body is executed.
    non_carry_over_env_vars_file = os.path.join(
        config_home, "ray_non_carry_over_env_vars.json")
    non_carry_over_env_vars = set()
    if os.path.exists(non_carry_over_env_vars_file):
        # Every name bound in a class body becomes a class attribute, so the
        # file handle gets a private name and is deleted afterwards instead
        # of lingering as a closed file object on the class. JSON is UTF-8,
        # so do not depend on the locale's default encoding.
        with open(non_carry_over_env_vars_file, encoding="utf-8") as _f:
            non_carry_over_env_vars = set(json.load(_f))
        del _f

    uses_ray: bool = True

    def _init_executor(self) -> None:
        """Pick the execution mode, start the Ray workers and set up codecs.

        With V1 enabled, SPMD workers and the Ray compiled DAG are forced on
        by writing their environment variables before they are read below.
        On TPU the compiled-DAG NCCL channel is additionally switched off.
        """
        self.forward_dag: Optional[ray.dag.CompiledDAG] = None
        if envs.VLLM_USE_V1:
            # V1 always drives SPMD workers through a compiled DAG.
            os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
            os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
            if current_platform.is_tpu():
                # For TPU, avoid compiling NVIDIA's NCCL.
                os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"

        # VLLM_USE_RAY_COMPILED_DAG=1 selects Ray's compiled DAG API, which
        # reduces control-plane overhead. VLLM_USE_RAY_SPMD_WORKER=1 drops the
        # special "driver worker": rank 0 then runs in a remote Ray worker
        # like every other rank. Currently each flag requires the other.
        self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
        self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
        if self.use_ray_compiled_dag:
            assert self.use_ray_spmd_worker, (
                "VLLM_USE_RAY_COMPILED_DAG=1 requires "
                "VLLM_USE_RAY_SPMD_WORKER=1")
        if self.use_ray_spmd_worker:
            # TODO: Support SPMD worker for non-DAG Ray executor.
            assert self.use_ray_compiled_dag, (
                "VLLM_USE_RAY_SPMD_WORKER=1 requires "
                "VLLM_USE_RAY_COMPILED_DAG=1")

        assert self.uses_ray
        initialize_ray_cluster(self.parallel_config)

        # Turn Ray usage-stats collection off unless the user explicitly
        # enabled it with RAY_USAGE_STATS_ENABLED=1.
        if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers on the configured placement group.
        self._init_workers_ray(self.parallel_config.placement_group)

        # msgpack codecs used for the V0 compiled-DAG input/output.
        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
        self.output_decoder = msgspec.msgpack.Decoder(
            Optional[List[SamplerOutput]])
        self.use_v1 = envs.VLLM_USE_V1

        # Per-PP-stage locks are created lazily, inside the event loop that
        # first needs them (see _driver_execute_model_async).
        self.pp_locks: Optional[List[asyncio.Lock]] = None
        if not self.use_ray_compiled_dag:
            self.driver_exec_method = make_async(
                self.driver_worker.execute_method)

    def shutdown(self) -> None:
        """Tear down the compiled DAG, if one was built, and its workers.

        The Ray worker actors are killed here only when a compiled DAG
        exists. ``forward_dag`` is then reset to None, so a repeated call
        skips the teardown.
        """
        logger.info(
            "Shutting down Ray distributed executor. If you see error log "
            "from logging.cc regarding SIGTERM received, please ignore because "
            "this is the expected termination process in Ray.")
        # getattr with a default: this can run (via __del__) on an instance
        # whose _init_executor never assigned forward_dag.
        dag = getattr(self, "forward_dag", None)
        if dag is None:
            return
        dag.teardown()
        import ray
        for worker in self.workers:
            ray.kill(worker)
        self.forward_dag = None

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        """Enable Nsight profiling for the Ray workers via their runtime env.

        The ``runtime_env`` entry of ``ray_remote_kwargs`` is created if it
        is missing, and its ``nsight`` key is set. The same (mutated) dict is
        returned.
        """
        nsight_config = {
            "t": "cuda,cudnn,cublas",
            "o": "'worker_process_%p'",
            "cuda-graph-trace": "node",
        }
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({"nsight": nsight_config})
        return ray_remote_kwargs

    # Subclasses may override this to return the env vars they actually need.
    def _get_env_vars_to_be_updated(self):
        """Return the per-worker env-var dicts passed to
        ``update_environment_variables`` on every worker.

        By default these are the dicts built in ``_init_workers_ray``.
        """
        env_vars_per_worker = self._env_vars_for_all_workers
        return env_vars_per_worker

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Create the Ray worker actors and initialize them.

        The steps are:
        - pick the placement-group bundles and create one actor per bundle;
        - in non-SPMD mode, turn the first actor found on the driver node
          into the resource holder of the local driver worker;
        - re-rank the remaining actors so that workers on the same node are
          contiguous;
        - propagate env vars, then run ``init_worker``, ``init_device`` and
          ``load_model`` on every worker;
        - build the TP/PP worker groupings.

        Args:
            placement_group: Ray placement group whose bundles host the
                workers.
            **ray_remote_kwargs: extra keyword arguments forwarded to
                ``ray.remote`` for every worker actor.

        Raises:
            ValueError: in non-SPMD mode, if no worker lands on the driver
                node.
            RuntimeError: if the number of worker nodes differs from the
                number of unique IP addresses.
        """
        num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: List[List[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)

        # Create the workers.
        bundle_indices: List[int]
        if envs.VLLM_RAY_BUNDLE_INDICES:
            # Use the bundle indices specified by the user.
            bundle_indices = list(
                map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
            assert len(bundle_indices) == self.parallel_config.world_size, (
                "VLLM_RAY_BUNDLE_INDICES must have the same size"
                f" as the world size, but got {bundle_indices=} "
                f"and {self.parallel_config.world_size=}")
            assert len(set(bundle_indices)) == len(bundle_indices), (
                "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
                f" but got {bundle_indices=}")
        else:
            # use the first N bundles that have GPU resources.
            bundle_indices = [
                bundle_id for bundle_id, bundle in enumerate(
                    placement_group.bundle_specs)
                if bundle.get(current_platform.ray_device_key, 0)
            ][:self.parallel_config.world_size]

        worker_metadata: List[RayWorkerMetaData] = []
        driver_ip = get_ip()
        for rank, bundle_id in enumerate(bundle_indices):
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            if current_platform.ray_device_key == "GPU":
                # NV+AMD GPUs, and Intel XPUs
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=num_gpus,
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
                                           rpc_rank=rank)
            else:
                worker = ray.remote(
                    num_cpus=0,
                    num_gpus=0,
                    resources={current_platform.ray_device_key: num_gpus},
                    scheduling_strategy=scheduling_strategy,
                    **ray_remote_kwargs,
                )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
                                           rpc_rank=rank)
            worker_metadata.append(
                RayWorkerMetaData(worker=worker, created_rank=rank))

        worker_ips = ray.get([
            each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
            for each in worker_metadata
        ])

        for each, ip in zip(worker_metadata, worker_ips):
            each.ip = ip

        # Creation rank of the actor reused as the driver's resource holder
        # (non-SPMD mode only); None when there is no such actor.
        driver_created_rank: Optional[int] = None
        if not self.use_ray_spmd_worker:
            for i, each in enumerate(worker_metadata):
                # find and remove the dummy worker from the list
                if each.ip == driver_ip:
                    # If the worker is on the same node as the driver, we use it
                    # as the resource holder for the driver process.
                    self.driver_dummy_worker = each.worker
                    driver_created_rank = each.created_rank
                    # _run_workers() also runs "adjust_rank" on this local
                    # driver worker. Give it the dummy's creation rank, which
                    # is mapped to 0 below. A fixed 0 would collide with the
                    # remote worker created with rank 0 whenever the dummy is
                    # not that worker, and the driver would be re-ranked.
                    self.driver_worker = RayWorkerWrapper(
                        vllm_config=self.vllm_config,
                        rpc_rank=driver_created_rank)
                    worker_metadata.pop(i)
                    break

        logger.debug("workers: %s", worker_metadata)
        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. "
                f"Driver IP: {driver_ip}, worker IPs: {worker_ips}. "
                "Consider adjusting the Ray placement group or running "
                "the driver on a GPU node.")

        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the worker is on a node with smaller IP address, it
                should be placed first.
            """
            ip = item.ip
            return (0 if ip == driver_ip else 1, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        sorted_worker_metadata = sorted(worker_metadata,
                                        key=sort_by_driver_then_worker_ip)
        start_rank = 0 if self.use_ray_spmd_worker else 1
        for i, item in enumerate(sorted_worker_metadata):
            item.adjusted_rank = i + start_rank
        self.workers = [item.worker for item in sorted_worker_metadata]
        rerank_mapping = {
            item.created_rank: item.adjusted_rank
            for item in sorted_worker_metadata
        }
        if driver_created_rank is not None:
            # The local driver worker takes rank 0 (remote workers start at
            # 1), so the mapping covers every creation rank exactly once.
            rerank_mapping[driver_created_rank] = 0
        self._run_workers("adjust_rank", rerank_mapping)

        # Get the set of GPU IDs used on each node. All requests are issued
        # first and gathered with a single ray.get() instead of one blocking
        # round trip per worker. Order: driver dummy (if any), then workers.
        workers_with_resources = [
            worker for worker in [self.driver_dummy_worker] + self.workers
            # driver_dummy_worker can be None when using ray spmd worker.
            if worker is not None
        ]
        worker_node_and_gpu_ids = ray.get([
            worker.get_node_and_gpu_ids.remote()  # type: ignore[attr-defined]
            for worker in workers_with_resources
        ])

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            node_gpus[node_id].extend(int(x) for x in gpu_ids)
        for gpu_ids in node_gpus.values():
            gpu_ids.sort()

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP`"
                " environment variable, make sure it is unique for"
                " each node.")

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [{
            current_platform.device_control_env_var:
            ",".join(map(str, node_gpus[node_id])),
        } for (node_id, _) in worker_node_and_gpu_ids]

        # Environment variables to copy from driver to workers
        env_vars_to_copy = [
            v for v in envs.environment_variables
            if v not in self.WORKER_SPECIFIC_ENV_VARS
            and v not in self.non_carry_over_env_vars
        ]

        env_vars_to_copy.extend(current_platform.additional_env_vars)

        # Copy existing env vars to each worker's args
        for args in all_args_to_update_environment_variables:
            # TODO: refactor platform-specific env vars
            for name in env_vars_to_copy:
                if name in os.environ:
                    args[name] = os.environ[name]

        logger.info("non_carry_over_env_vars from config: %s",
                    self.non_carry_over_env_vars)
        logger.info(
            "Copying the following environment variables to workers: %s",
            [v for v in env_vars_to_copy if v in os.environ])
        logger.info(
            "If certain env vars should NOT be copied to workers, add them to "
            "%s file", self.non_carry_over_env_vars_file)

        self._env_vars_for_all_workers = (
            all_args_to_update_environment_variables)

        self._run_workers("update_environment_variables",
                          self._get_env_vars_to_be_updated())

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            local_rank = node_workers[node_id].index(rank)
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
            all_kwargs.append(kwargs)
        self._run_workers("init_worker", all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        if self.use_ray_spmd_worker:
            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                self.pp_tp_workers.append([])
                for tp_rank in range(
                        self.parallel_config.tensor_parallel_size):
                    # PP=2, TP=4
                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
                            ) + tp_rank
                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                    assert pp_rank < len(self.pp_tp_workers)
                    self.pp_tp_workers[pp_rank].append(self.workers[rank])

        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[RayWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: List[RayWorkerWrapper] = []

        # Enforce rank order for correct rank to return final output.
        for index, worker in enumerate(self.workers):
            # The driver worker is rank 0 and not in self.workers.
            rank = index + 1
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(worker)
            else:
                self.non_driver_workers.append(worker)

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run ``execute_model`` on the local driver worker.

        Not available in SPMD mode, where no driver worker exists. Passing
        None will cause the driver to stop the model execution loop running
        in each of the remote workers.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
        driver = self.driver_worker
        return driver.execute_method("execute_model", execute_model_req)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Execute one model step.

        Without SPMD workers this defers to the base-class implementation.
        Otherwise the request runs through the compiled Ray DAG, which is
        built on first use. V1 passes the request and the first DAG output
        through unchanged; V0 msgpack-encodes the request and decodes the
        first output.
        """
        if not self.use_ray_spmd_worker:
            return super().execute_model(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        dag_input = (execute_model_req if self.use_v1 else
                     self.input_encoder.encode(execute_model_req))
        first_output = ray.get(self.forward_dag.execute(dag_input))[0]
        if self.use_v1:
            return first_output
        return self.output_decoder.decode(first_output)

    def _run_workers(
        self,
        method: Union[str, Callable],
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Run ``method`` on the Ray workers and, outside SPMD mode, on the
        local driver worker.

        Args:
        - method: name of the worker method to call, or a callable, which is
          serialized with cloudpickle before being sent.
        - args/kwargs: All workers share the same args/kwargs.
        - async_run_tensor_parallel_workers_only: If True, the method is run
          only on ``self.non_driver_workers`` (never the driver worker) and
          the list of Ray object refs is returned without waiting for the
          results. Not supported in SPMD mode.
        - max_concurrent_workers: not supported; any truthy value raises
          NotImplementedError.

        Returns:
            In blocking mode, a list holding the driver worker's result first
            (non-SPMD only), followed by one result per worker in
            ``self.workers`` order.
        """
        sent_method = (method if isinstance(method, str) else
                       cloudpickle.dumps(method))
        del method

        if self.use_ray_spmd_worker:
            assert not async_run_tensor_parallel_workers_only, (
                "async_run_tensor_parallel_workers_only is not supported for "
                "spmd mode.")

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        # Kick off the Ray workers before the (blocking) driver call.
        targets = (self.non_driver_workers
                   if async_run_tensor_parallel_workers_only else self.workers)
        ray_worker_outputs = [
            worker.execute_method.remote(sent_method, *args, **kwargs)
            for worker in targets
        ]
        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return ray_worker_outputs

        # In SPMD mode rank 0 is an ordinary Ray worker, so there is no local
        # driver worker to call.
        driver_worker_output = []
        if not self.use_ray_spmd_worker:
            driver_worker_output = [
                self.driver_worker.execute_method(sent_method, *args, **kwargs)
            ]

        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return driver_worker_output + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Block until the Ray object refs returned by ``_run_workers()``
        with ``async_run_tensor_parallel_workers_only=True`` have completed.

        The results themselves are discarded.
        """
        pending_refs = parallel_worker_tasks
        ray.get(pending_refs)

    def _check_ray_cgraph_installation(self):
        """Verify that Ray Compiled Graph and its dependencies are usable.

        Raises:
            ValueError: if the installed Ray is older than 2.43.0, if
                ``ray.experimental.compiled_dag_ref`` cannot be found, or if
                cupy is missing while VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL
                is enabled.
        """
        import importlib.metadata
        import importlib.util

        from packaging import version

        required_version = version.parse("2.43.0")
        # importlib.metadata replaces the deprecated pkg_resources API, which
        # ships with setuptools and is not guaranteed to be installed.
        current_version = version.parse(importlib.metadata.version("ray"))
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} is "
                             f"required, but found {current_version}")

        cgraph_spec = importlib.util.find_spec(
            "ray.experimental.compiled_dag_ref")
        if cgraph_spec is None:
            raise ValueError("Ray Compiled Graph is not installed. "
                             "Run `pip install ray[cgraph]` to install it.")

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. "
                "Run `pip install ray[cgraph]` and check cupy installation.")

    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the Ray DAG that drives the SPMD workers.

        Every worker of the first PP stage receives the DAG input. TP worker
        i of each later stage consumes the output of TP worker i of the
        previous stage. Between stages, outputs use the "nccl" tensor
        transport when VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set and
        "auto" otherwise. V1 binds ``execute_model_ray`` on each worker; V0
        binds ``execute_model_spmd``.
        """
        assert self.parallel_config.use_ray
        self._check_ray_cgraph_installation()
        from ray.dag import InputNode, MultiOutputNode

        logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s",
                    envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL)
        logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
                    envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)

        # Example DAG: PP=2, TP=4
        #
        # For V0:
        # ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput   # noqa: E501
        # ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput   # noqa: E501
        # ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput   # noqa: E501
        # ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput   # noqa: E501
        #
        # For V1:
        # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput   # noqa: E501
        # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput   # noqa: E501
        # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput   # noqa: E501
        # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput   # noqa: E501
        step_method = ("execute_model_ray"
                       if self.use_v1 else "execute_model_spmd")
        inter_stage_transport = ("nccl"
                                 if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL
                                 else "auto")
        last_pp_rank = len(self.pp_tp_workers) - 1

        with InputNode() as input_data:
            # Every TP worker of the first PP stage reads the DAG input.
            stage_outputs = [input_data for _ in self.pp_tp_workers[0]]
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # The TP group runs in SPMD fashion: TP worker i consumes
                # output i of the previous PP stage.
                stage_outputs = [
                    getattr(worker, step_method).bind(stage_outputs[tp_rank])
                    for tp_rank, worker in enumerate(tp_group)
                ]
                if pp_rank < last_pp_rank:
                    # Intermediate tensors passed to the next PP stage need
                    # an explicit transport; the last stage's outputs do not.
                    stage_outputs = [
                        node.with_tensor_transport(
                            transport=inter_stage_transport)
                        for node in stage_outputs
                    ]

            forward_dag = MultiOutputNode(stage_outputs)

        return forward_dag.experimental_compile(
            enable_asyncio=enable_asyncio,
            _overlap_gpu_communication=envs.
            VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)

    def __del__(self):
        """Finalizer: delegate cleanup to ``shutdown()`` when the executor
        is garbage-collected."""
        self.shutdown()

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Async counterpart of ``execute_model``.

        Without SPMD workers this defers to the base-class implementation.
        Otherwise the request runs through the compiled Ray DAG, built with
        asyncio support on first use.

        The ``use_v1`` handling matches ``execute_model``: V1 passes the
        request and the first DAG output through unchanged. Only V0
        msgpack-encodes the request and decodes the output. The previous
        unconditional encode would hand V1 workers msgpack bytes instead of
        the request object.
        """
        if not self.use_ray_spmd_worker:
            return await super().execute_model_async(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

        if self.use_v1:
            dag_input = execute_model_req
        else:
            dag_input = self.input_encoder.encode(execute_model_req)
        dag_future = await self.forward_dag.execute_async(dag_input)
        output = await dag_future[0]
        if self.use_v1:
            return output
        return self.output_decoder.decode(output)

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run ``execute_model`` on the driver and the per-stage TP drivers.

        When ``tp_driver_workers`` is empty, only the local driver runs.
        Otherwise the local driver (stage 0) and each entry of
        ``tp_driver_workers`` (stages 1, 2, ...) run concurrently, each under
        its stage's lock, and the last stage's result is returned.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
        if not self.tp_driver_workers:
            return await self.driver_exec_method("execute_model",
                                                 execute_model_req)

        if self.pp_locks is None:
            # One lock per pipeline-parallel stage, so several virtual engines
            # never execute the same stage at the same time. Created here
            # rather than in the constructor, which runs under a different
            # asyncio loop.
            self.pp_locks = [
                asyncio.Lock()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]

        stage_callables = [self.driver_exec_method] + [
            worker.execute_method.remote for worker in self.tp_driver_workers
        ]
        tasks = [
            asyncio.create_task(
                _run_task_with_lock(stage_fn, self.pp_locks[pp_rank],
                                    "execute_model", execute_model_req))
            for pp_rank, stage_fn in enumerate(stage_callables)
        ]
        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]

    async def _start_worker_execution_loop(self):
        """Invoke ``start_worker_execution_loop`` on every non-driver worker
        and await all of them, returning their results as a list."""
        assert not self.use_ray_spmd_worker, (
            "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
        loop_refs = [
            follower.execute_method.remote("start_worker_execution_loop")
            for follower in self.non_driver_workers
        ]
        return await asyncio.gather(*loop_refs)

    def check_health(self) -> None:
        """Health-check hook; currently a no-op.

        The Ray workers are assumed to be healthy.
        TODO: check the health of the Ray workers
        """
        return None