# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from collections import Counter, defaultdict
from collections.abc import Callable
from concurrent.futures import Future
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import cloudpickle

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.utils.network_utils import (
    get_distributed_init_method,
    get_ip,
    get_open_port,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (
    FutureWrapper,
    RayWorkerWrapper,
    initialize_ray_cluster,
    ray,
)
from vllm.v1.outputs import ModelRunnerOutput

# `ray` (re-exported by ray_utils) may be None — presumably when Ray is not
# installed. In that case the Ray-only names are not imported, and
# ActorHandle falls back to None so annotations that use it still evaluate.
if ray is not None:
    from ray.actor import ActorHandle
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
    ActorHandle = None

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


@dataclass
class RayWorkerMetaData:
    """
    Metadata for a Ray worker.

    The order of Ray worker creation can be random, so the rank a worker was
    created with (``created_rank``) is replaced by ``adjusted_rank`` after
    all workers have been created and sorted.
    """

    # Handle to the Ray actor that hosts the worker.
    worker: ActorHandle
    # Rank passed to the worker at creation time (its creation order).
    created_rank: int
    # Rank assigned after sorting all workers; -1 until it is assigned.
    adjusted_rank: int = -1
    # IP address of the node running the worker; "" until it is queried.
    ip: str = ""


class RayDistributedExecutor(Executor):
    """Ray-based distributed executor.

    Creates one ``RayWorkerWrapper`` actor per rank inside the configured
    placement group, initializes them, and runs model execution through a
    Ray Compiled Graph (``forward_dag``) that is built lazily on the first
    ``execute_model`` call.
    """

    # These env vars are worker-specific, therefore are NOT copied
    # from the driver to the workers
    WORKER_SPECIFIC_ENV_VARS = {
        "VLLM_HOST_IP",
        "VLLM_HOST_PORT",
        "LOCAL_RANK",
        "CUDA_VISIBLE_DEVICES",
    }

    # These non-vLLM env vars are copied from the driver to workers
    ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}

    uses_ray: bool = True
    supports_pp: bool = True

    def _init_executor(self) -> None:
        """Start the Ray cluster, then create and initialize all workers."""
        # The compiled DAG is built lazily on the first execute_model() call.
        self.forward_dag: ray.dag.CompiledDAG | None = None

        # For TPU or XPU, avoid compiling NVIDIA's NCCL
        if current_platform.is_tpu() or current_platform.is_xpu():
            os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"

        assert self.uses_ray
        initialize_ray_cluster(self.parallel_config)
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection unless explicitly enabled ("1").
        # NOTE(review): this runs after initialize_ray_cluster(); confirm it
        # still takes effect for an already-started Ray cluster.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        # KV connector setup: with a connector, execute_model() aggregates
        # the outputs of all workers instead of reading a single one.
        self.has_connector = self.vllm_config.kv_transfer_config is not None

    @property
    def max_concurrent_batches(self) -> int:
        """Ray distributed executor supports pipeline parallelism,
        meaning that it allows PP size batches to be executed concurrently.

        With async scheduling enabled, this is fixed at 2.
        """
        if self.scheduler_config.async_scheduling:
            return 2
        return self.parallel_config.pipeline_parallel_size

    def shutdown(self) -> None:
        """Tear down the compiled DAG, if one was built, and kill the workers.

        Workers are only killed here when a compiled DAG exists. The call is a
        no-op after the first successful call, and also when initialization
        failed before ``forward_dag`` was assigned. This makes it safe to call
        from ``__del__``.
        """
        if logger:
            # Somehow logger can be None here.
            logger.info(
                "Shutting down Ray distributed executor. If you see error log "
                "from logging.cc regarding SIGTERM received, please ignore "
                "because this is the expected termination process in Ray."
            )

        forward_dag = getattr(self, "forward_dag", None)
        if forward_dag is None:
            return
        # Clear the reference first so that a repeated call (e.g. from
        # __del__ after an explicit shutdown) does not tear down twice.
        self.forward_dag = None
        try:
            forward_dag.teardown()
        finally:
            # Kill the actors even if teardown() raised, so they are not
            # leaked.
            import ray

            for worker in self.workers:
                ray.kill(worker)

    def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]:
        """Add an Nsight profiling runtime env to ``ray_remote_kwargs``.

        Mutates ``ray_remote_kwargs`` in place and returns it.
        """
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update(
            {
                "nsight": {
                    "t": "cuda,cudnn,cublas",
                    "o": "'worker_process_%p'",
                    "cuda-graph-trace": "node",
                }
            }
        )

        return ray_remote_kwargs

    # child class could overwrite this to return actual env vars.
    def _get_env_vars_to_be_updated(self):
        """Return the per-worker env-var dicts sent to the workers."""
        return self._env_vars_for_all_workers

    def _get_bundle_indices(self, placement_group: "PlacementGroup") -> list[int]:
        """Return the placement-group bundle index to use for each rank.

        Uses ``VLLM_RAY_BUNDLE_INDICES`` when it is set. Otherwise, picks the
        first ``world_size`` bundles that hold the platform's Ray device
        resource.

        Raises:
            AssertionError: If the user-provided indices have the wrong length
                or contain duplicates.
            ValueError: If the placement group has fewer device bundles than
                ``world_size``. Without this check, too few workers would be
                created and initialization would fail later.
        """
        world_size = self.parallel_config.world_size
        if envs.VLLM_RAY_BUNDLE_INDICES:
            # Use the bundle indices specified by the user.
            bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
            assert len(bundle_indices) == self.parallel_config.world_size, (
                "VLLM_RAY_BUNDLE_INDICES must have the same size"
                f" as the world size, but got {bundle_indices=} "
                f"and {self.parallel_config.world_size=}"
            )
            assert len(set(bundle_indices)) == len(bundle_indices), (
                "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
                f" but got {bundle_indices=}"
            )
            return bundle_indices

        # Use the first N bundles that have device resources.
        device_key = current_platform.ray_device_key
        bundle_indices = [
            bundle_id
            for bundle_id, bundle in enumerate(placement_group.bundle_specs)
            if bundle.get(device_key, 0)
        ][:world_size]
        if len(bundle_indices) < world_size:
            raise ValueError(
                f"The placement group has only {len(bundle_indices)} bundles "
                f"with '{device_key}' resources, but the world size is "
                f"{world_size}."
            )
        return bundle_indices

    def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs):
        """Create, re-rank, configure and initialize the Ray worker actors.

        The steps are:

        1. Start one ``RayWorkerWrapper`` actor per selected bundle.
        2. Re-rank the actors so that workers on the driver node, then on
           less-populated nodes, come first.
        3. Send each worker its device-visibility and copied env vars.
        4. Run ``init_worker``, ``init_device`` and ``load_model`` on every
           worker.
        5. Build ``pp_tp_workers``.

        Args:
            placement_group: Ray placement group whose bundles host the workers.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote`` for
                every worker actor.

        Raises:
            RuntimeError: If the number of distinct node IDs differs from the
                number of distinct IPs (driver IP included).
        """
        num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: RayWorkerWrapper | None = None
        # The remaining workers are the actual ray actors.
        self.workers: list[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: list[list[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs
            )

        # Device resources requested per actor. NV+AMD GPUs and Intel XPUs
        # use Ray's built-in GPU resource. Other platforms request a custom
        # resource named after the platform's Ray device key.
        device_key = current_platform.ray_device_key
        if device_key == "GPU":
            device_resources: dict[str, Any] = {"num_gpus": num_gpus}
        else:
            device_resources = {"num_gpus": 0, "resources": {device_key: num_gpus}}

        # Create the workers.
        bundle_indices = self._get_bundle_indices(placement_group)
        worker_metadata: list[RayWorkerMetaData] = []
        driver_ip = get_ip()
        for rank, bundle_id in enumerate(bundle_indices):
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            worker = ray.remote(
                num_cpus=0,
                scheduling_strategy=scheduling_strategy,
                **device_resources,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(  # type: ignore[attr-defined]
                vllm_config=self.vllm_config, rpc_rank=rank
            )
            worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))

        worker_ips = ray.get(
            [
                each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
                for each in worker_metadata
            ]
        )

        for each, ip in zip(worker_metadata, worker_ips):
            each.ip = ip

        logger.debug("workers: %s", worker_metadata)
        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)

        # Number of workers hosted on each node IP.
        ip_counts = Counter(worker_ips)

        def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the worker is on a node with a smaller IP address
                (compared as strings), it should be placed first.
            """
            ip = item.ip
            return 0 if ip == driver_ip else 1, ip_counts[ip], ip

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        sorted_worker_metadata = sorted(
            worker_metadata, key=sort_by_driver_then_worker_ip
        )
        for i, item in enumerate(sorted_worker_metadata):
            item.adjusted_rank = i
        self.workers = [item.worker for item in sorted_worker_metadata]
        rerank_mapping = {
            item.created_rank: item.adjusted_rank for item in sorted_worker_metadata
        }
        self.collective_rpc("adjust_rank", args=(rerank_mapping,))

        # Get the (node id, GPU ids) pair of every worker. Submit all requests
        # first and wait once, instead of one blocking round trip per worker.
        # driver_dummy_worker can be None when using ray spmd worker.
        queried_workers = [
            worker
            for worker in [self.driver_dummy_worker, *self.workers]
            if worker is not None
        ]
        worker_node_and_gpu_ids = ray.get(
            [
                worker.get_node_and_gpu_ids.remote()  # type: ignore[attr-defined]
                for worker in queried_workers
            ]
        )

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            node_gpus[node_id].extend(int(x) for x in gpu_ids)
        for gpu_id_list in node_gpus.values():
            gpu_id_list.sort()

        all_ips = set(worker_ips + [driver_ip])
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP`"
                " environment variable, make sure it is unique for"
                " each node."
            )

        # Set environment variables for the driver and workers. Each worker
        # sees every device used on its node (sorted numerically).
        all_args_to_update_environment_variables = [
            {
                current_platform.device_control_env_var: ",".join(
                    map(str, node_gpus[node_id])
                ),
            }
            for (node_id, _) in worker_node_and_gpu_ids
        ]

        # Environment variables to copy from driver to workers
        env_vars_to_copy = get_env_vars_to_copy(
            exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
            additional_vars=set(current_platform.additional_env_vars).union(
                self.ADDITIONAL_ENV_VARS
            ),
            destination="workers",
        )

        # Copy existing env vars to each worker's args
        for args in all_args_to_update_environment_variables:
            # TODO: refactor platform-specific env vars
            for name in env_vars_to_copy:
                if name in os.environ:
                    args[name] = os.environ[name]

        self._env_vars_for_all_workers = all_args_to_update_environment_variables

        self.collective_rpc(
            "update_environment_variables", args=(self._get_env_vars_to_be_updated(),)
        )

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port()
        )

        # Initialize the actual workers inside worker wrapper.
        all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            # Local rank = position of this worker among workers on its node.
            local_rank = node_workers[node_id].index(rank)
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
            all_kwargs.append(kwargs)
        self.collective_rpc("init_worker", args=(all_kwargs,))

        self.collective_rpc("init_device")
        self.collective_rpc("load_model")

        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
            self.pp_tp_workers.append([])
            for tp_rank in range(self.parallel_config.tensor_parallel_size):
                # PP=2, TP=4
                # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                rank = (pp_rank * self.parallel_config.tensor_parallel_size) + tp_rank
                assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                assert pp_rank < len(self.pp_tp_workers)
                self.pp_tp_workers[pp_rank].append(self.workers[rank])

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Forward a distributed reconfiguration request to all workers.

        Shuts this executor down afterwards if the request asks the current
        rank to shut down.
        """
        self.collective_rpc("reinitialize_distributed", args=(reconfig_request,))
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()

    def execute_model(  # type: ignore[override]
        self, scheduler_output: SchedulerOutput, non_block: bool = False
    ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
        """Execute the model on the Ray workers.

        Args:
            scheduler_output: The scheduler output to execute.
            non_block: If True, the method will return a Future.

        Returns:
            The model runner output, or a Future of it when ``non_block``.
            Without a KV connector, only the first DAG output is used. With a
            connector, all outputs are aggregated by ``kv_output_aggregator``.
        """
        # Build the compiled DAG for the first time.
        if self.forward_dag is None:  # type: ignore
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        refs = self.forward_dag.execute(scheduler_output)  # type: ignore

        if not self.has_connector:
            # Get output only from a single worker (output_rank).
            # In blocking mode, wait here until the result is available.
            if not non_block:
                return refs[0].get()

            # In non-blocking mode (e.g. with PP), return a FutureWrapper
            # immediately so that the scheduler can yield to the next batch.
            return FutureWrapper(refs)

        # Get output from all workers when connector is present
        assert self.kv_output_aggregator is not None
        if not non_block:
            # Block and get results from all workers
            outputs = [ref.get() for ref in refs]
            return self.kv_output_aggregator.aggregate(outputs)

        # Return a future that will aggregate outputs from all workers
        return FutureWrapper(refs, self.kv_output_aggregator)

    def collective_rpc(
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
        non_block: bool = False,
    ) -> list[Any]:
        """Runs the given method on all workers.

        Args:
            method: Name of a worker method, or a callable (serialized with
                cloudpickle before being sent).
            timeout: Passed to ``ray.get``. Unused when ``non_block`` is True.
            args: Positional arguments forwarded to the method.
            kwargs: Keyword arguments forwarded to the method.
            non_block: If True, return one ``FutureWrapper`` per worker
                instead of waiting for results.

        Returns:
            Per-worker results (or futures), in ``self.workers`` order.
        """
        sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
        # Drop the local reference; only the (possibly pickled) form is used.
        del method

        if kwargs is None:
            kwargs = {}
        ray_worker_outputs = [
            worker.execute_method.remote(  # type: ignore[attr-defined]
                sent_method, *args, **kwargs
            )
            for worker in self.workers
        ]

        # Get the results of the ray workers.
        if non_block:
            return [FutureWrapper((output,)) for output in ray_worker_outputs]

        return ray.get(ray_worker_outputs, timeout=timeout)

    def _check_ray_cgraph_installation(self):
        """Validate that Ray Compiled Graph and its dependencies are usable.

        Raises:
            ValueError: If Ray is older than 2.43.0, the Ray Compiled Graph
                module cannot be found, or cupy is missing while
                ``VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE`` is ``"nccl"``.
        """
        import importlib.metadata
        import importlib.util

        from packaging import version

        required_version = version.parse("2.43.0")
        current_version = version.parse(importlib.metadata.version("ray"))
        if current_version < required_version:
            raise ValueError(
                f"Ray version {required_version} is "
                f"required, but found {current_version}"
            )

        cgraph_spec = importlib.util.find_spec("ray.experimental.compiled_dag_ref")
        if cgraph_spec is None:
            raise ValueError(
                "Ray Compiled Graph is not installed. "
                "Run `pip install ray[cgraph]` to install it."
            )

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl":
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
                "Run `pip install ray[cgraph]` and check cupy installation."
            )

    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the Ray Compiled Graph used by execute_model().

        Every PP stage's TP group binds ``execute_model_ray`` to the outputs
        of the previous stage; the first stage receives the DAG input. Between
        stages, the tensor transport is set to the configured channel type
        unless it is ``"shm"``.

        Args:
            enable_asyncio: Forwarded to ``experimental_compile``.

        Raises:
            ValueError: If Ray Compiled Graph prerequisites are missing or
                ``VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE`` is invalid.
        """
        assert self.parallel_config.use_ray
        self._check_ray_cgraph_installation()
        # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
        # (it is 10 seconds by default). This is a Ray environment variable to
        # control the timeout of getting result from a compiled graph execution,
        # i.e., the distributed execution that includes model forward runs and
        # intermediate tensor communications, in the case of vllm.
        # Note: we should set this env var before importing
        # ray.dag, otherwise it will not take effect.
        os.environ.setdefault("RAY_CGRAPH_get_timeout", "300")  # noqa: SIM112
        from ray.dag import InputNode, MultiOutputNode

        logger.info(
            "RAY_CGRAPH_get_timeout is set to %s",
            os.environ["RAY_CGRAPH_get_timeout"],  # noqa: SIM112
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE,
        )
        logger.info(
            "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
            envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )

        channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
        if channel_type not in ("auto", "nccl", "shm"):
            raise ValueError(
                "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
                f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'."
            )

        last_pp_rank = len(self.pp_tp_workers) - 1
        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            #
            # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput   # noqa: E501
            # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput   # noqa: E501

            # All workers in the first TP group will take in the
            # SchedulerOutput as input.
            outputs = [input_data for _ in self.pp_tp_workers[0]]
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                outputs = [
                    worker.execute_model_ray.bind(outputs[i])  # type: ignore[attr-defined]
                    for i, worker in enumerate(tp_group)
                ]

                if pp_rank < last_pp_rank and channel_type != "shm":
                    # Specify how intermediate tensors should be passed
                    # between pp stages, no need to specify for the last
                    # pp stage or when using shared memory.
                    outputs = [
                        output.with_tensor_transport(transport=channel_type)
                        for output in outputs
                    ]

            forward_dag = MultiOutputNode(outputs)

        if envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
            from ray.experimental.channel.accelerator_context import (
                register_accelerator_context,
            )

            from vllm.distributed.device_communicators.ray_communicator import (
                RayPPCommunicator,
            )

            register_accelerator_context(
                torch_module_name="cuda", communicator_cls=RayPPCommunicator
            )
            logger.info(
                "Using RayPPCommunicator "
                "(which wraps vLLM _PP GroupCoordinator) "
                "for Ray Compiled Graph communication."
            )
        else:
            logger.info(
                "Using Ray's NCCL communicator for Ray Compiled Graph communication."
            )

        return forward_dag.experimental_compile(
            enable_asyncio=enable_asyncio,
            _overlap_gpu_communication=envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
        )

    def __del__(self):
        """Best-effort cleanup when the executor is garbage-collected."""
        self.shutdown()

    def check_health(self) -> None:
        """Report executor health; currently always reports healthy."""
        # Assume that the Ray workers are healthy.
        # TODO: check the health of the Ray workers
        return