# ray_gpu_executor.py
import asyncio
import os
3
from collections import defaultdict
4
from itertools import islice, repeat
5
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
6

7
8
import msgspec

9
import vllm.envs as envs
10
11
from vllm.executor.distributed_gpu_executor import (  # yapf: disable
    DistributedGPUExecutor, DistributedGPUExecutorAsync)
12
from vllm.executor.msgspec_utils import encode_hook
13
from vllm.executor.ray_utils import RayWorkerWrapper, ray
14
from vllm.logger import init_logger
15
from vllm.sequence import ExecuteModelRequest, SamplerOutput
16
17
18
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
                        get_ip, get_open_port, get_vllm_instance_id,
                        make_async)
19
20
21
22
23
24
25
26
27
28

# Type-only import: PlacementGroup appears only in string annotations below.
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# ``ray`` (from vllm.executor.ray_utils) may be None, presumably when Ray is
# unavailable; only import the scheduling strategy when it is usable.
if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

logger = init_logger(__name__)


29
class RayGPUExecutor(DistributedGPUExecutor):
    """Distributed GPU executor whose workers run as Ray actors.

    Two execution modes exist, selected in ``_init_executor`` from env vars:

    * Default: rank 0 runs in this process as ``self.driver_worker`` and the
      remaining ranks are Ray actors in ``self.workers``; work is fanned out
      through ``_run_workers``.
    * SPMD (``VLLM_USE_RAY_SPMD_WORKER=1``, which currently also requires
      ``VLLM_USE_RAY_COMPILED_DAG=1``): every rank, including rank 0, is a
      Ray actor in ``self.workers`` and ``execute_model`` runs through a Ray
      compiled DAG built lazily by ``_compiled_ray_dag``.
    """

    uses_ray: bool = True

    def _init_executor(self) -> None:
        """Read the Ray mode flags, create the Ray workers, and set up the
        msgpack encoder/decoder used on the compiled-DAG path."""
        self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
        # If the env var is set, it uses the Ray's compiled DAG API
        # which optimizes the control plane overhead.
        # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
        # Currently, this requires USE_RAY_SPMD_WORKER=True.
        self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
        # If the env var is set, then we do not distinguish between the
        # "driver worker" vs other workers. Also, the rank 0 worker will
        # be executed in a remote Ray worker. Currently this requires
        # USE_RAY_COMPILED_DAG=True.
        self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
        if self.use_ray_compiled_dag:
            assert self.use_ray_spmd_worker, (
                "VLLM_USE_RAY_COMPILED_DAG=1 requires "
                "VLLM_USE_RAY_SPMD_WORKER=1")
        if self.use_ray_spmd_worker:
            # TODO: Support SPMD worker for non-DAG Ray executor.
            assert self.use_ray_compiled_dag, (
                "VLLM_USE_RAY_SPMD_WORKER=1 requires "
                "VLLM_USE_RAY_COMPILED_DAG=1")

        assert self.uses_ray
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection unless the user explicitly
        # opted in with RAY_USAGE_STATS_ENABLED=1.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        # Requests sent to, and outputs received from, the compiled DAG are
        # serialized with msgpack.
        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
        self.output_decoder = msgspec.msgpack.Decoder(
            Optional[List[SamplerOutput]])

    def shutdown(self) -> None:
        """Tear down the compiled DAG, if one was built, and kill the Ray
        worker actors.

        A second call is a no-op because ``forward_dag`` is reset to None.
        """
        # hasattr guard: presumably for __del__ running on an object whose
        # _init_executor never completed.
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            import ray
            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        """Add an nsight profiling ``runtime_env`` to ``ray_remote_kwargs``.

        The dict is modified in place and also returned.
        """
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    def _get_worker_wrapper_args(self) -> Dict[str, Any]:
        """Return the constructor kwargs for ``RayWorkerWrapper``.

        The kwargs select the speculative-decoding worker factory when a
        speculative config is present, and the plain ``Worker`` otherwise.
        """
        if self.speculative_config is not None:
            worker_module_name = "vllm.spec_decode.spec_decode_worker"
            worker_class_name = "create_spec_worker"
        else:
            worker_module_name = "vllm.worker.worker"
            worker_class_name = "Worker"

        return dict(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
            trust_remote_code=self.model_config.trust_remote_code,
        )

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Create one Ray actor per GPU bundle in ``placement_group`` and
        initialize a vLLM worker inside each.

        In non-SPMD mode, the first actor placed on the driver's node becomes
        ``self.driver_dummy_worker``, a resource holder. A local
        ``self.driver_worker`` is created in this process, and every other
        actor goes to ``self.workers``. In SPMD mode, every actor goes to
        ``self.workers``.

        Args:
            placement_group: Ray placement group whose GPU bundles decide
                where workers are scheduled.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote``.

        Raises:
            ValueError: In non-SPMD mode, if no GPU bundle lands on the
                driver node.
        """
        if (self.parallel_config.tensor_parallel_size == 1
                and self.parallel_config.pipeline_parallel_size == 1):
            # For single GPU case, we use a ray worker with constrained memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Used in ray compiled DAG: indexed first by PP rank,
        # and then TP rank. In other words, the inner list is
        # the TP group of workers for a PP rank.
        self.pp_tp_workers: List[List[RayWorkerWrapper]] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)

        # Create the workers.
        driver_ip = get_ip()
        worker_wrapper_kwargs = self._get_worker_wrapper_args()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)

            if self.use_ray_spmd_worker:
                self.workers.append(worker)
            else:
                worker_ip = ray.get(worker.get_node_ip.remote())
                if worker_ip == driver_ip and self.driver_dummy_worker is None:
                    # If the worker is on the same node as the driver, we use it
                    # as the resource holder for the driver process.
                    self.driver_dummy_worker = worker
                    self.driver_worker = RayWorkerWrapper(
                        **worker_wrapper_kwargs)
                else:
                    # Else, added to the list of workers.
                    self.workers.append(worker)

        logger.debug("workers: %s", self.workers)
        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        # Query each worker's node IP exactly once. The result feeds both the
        # per-node counts and the sort key, so no further remote calls are
        # needed while sorting.
        worker_ips = [
            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
            for worker in self.workers
        ]
        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(ip_and_worker):
            """
            Sort (ip, worker) pairs based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
                it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
                be placed first.
            3. Finally, if the worker is on a node with smaller IP address, it
                should be placed first.
            """
            ip = ip_and_worker[0]
            return (ip != driver_ip, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first. ``sorted`` is stable and the key only
        # looks at the IP, so workers with equal keys keep creation order.
        self.workers = [
            worker for _, worker in sorted(zip(worker_ips, self.workers),
                                           key=sort_by_driver_then_worker_ip)
        ]

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        VLLM_INSTANCE_ID = get_vllm_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "VLLM_INSTANCE_ID":
            VLLM_INSTANCE_ID,
            "VLLM_TRACE_FUNCTION":
            str(envs.VLLM_TRACE_FUNCTION),
        }, ) for (node_id, _) in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        if self.use_ray_spmd_worker:
            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                self.pp_tp_workers.append([])
                for tp_rank in range(
                        self.parallel_config.tensor_parallel_size):
                    # PP=2, TP=4
                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
                            ) + tp_rank
                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                    assert pp_rank < len(self.pp_tp_workers)
                    self.pp_tp_workers[pp_rank].append(self.workers[rank])

        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[RayWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: List[RayWorkerWrapper] = []

        # Enforce rank order for correct rank to return final output.
        # NOTE(review): the `index + 1` mapping assumes rank 0 is the local
        # driver and is NOT in self.workers, which only holds in non-SPMD mode.
        # In SPMD mode these two lists are built but the paths that read them
        # assert they are not in SPMD mode.
        for index, worker in enumerate(self.workers):
            # The driver worker is rank 0 and not in self.workers.
            rank = index + 1
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(worker)
            else:
                self.non_driver_workers.append(worker)

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Execute one model step.

        Outside SPMD mode this defers to the base-class implementation. In
        SPMD mode the request is msgpack-encoded and run through the compiled
        DAG, which is built on the first call. The output returned is the one
        from the first TP rank of the last PP stage.
        """
        if not self.use_ray_spmd_worker:
            return super().execute_model(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        serialized_data = self.input_encoder.encode(execute_model_req)
        outputs = ray.get(self.forward_dag.execute(serialized_data))
        output = self.output_decoder.decode(outputs[0])
        return output

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
        - method: Name of the worker method, invoked via ``execute_method``.
        - async_run_tensor_parallel_workers_only: If True the method will be
          run only in the remote non-driver TP workers
          (``self.non_driver_workers``), not the driver worker.
          It will also be run asynchronously and return a list of futures
          rather than blocking on the results. Not supported in SPMD mode.
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually. Outside SPMD mode, entry 0 goes to the driver and the
          remaining entries go to ``self.workers`` in order.
        - use_dummy_driver: Outside SPMD mode, run the driver's share of the
          call on the remote ``driver_dummy_worker`` actor instead of the
          local ``driver_worker``.
        - max_concurrent_workers: Not supported; any truthy value raises
          NotImplementedError.

        Returns:
            With ``async_run_tensor_parallel_workers_only``: the list of
            un-awaited Ray object refs. Otherwise: the driver's result
            (non-SPMD mode only) followed by the Ray workers' results.
        """
        if self.use_ray_spmd_worker:
            assert not async_run_tensor_parallel_workers_only, (
                "async_run_tensor_parallel_workers_only is not supported for "
                "spmd mode.")

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers) if not \
            async_run_tensor_parallel_workers_only \
            else len(self.non_driver_workers)
        # If using SPMD worker, all workers are the same, so we should execute
        # the args on all workers. Otherwise, we skip the first worker's args
        # because those args will go to the driver worker.
        first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, first_worker_args_index, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, first_worker_args_index, None)

        # Start the ray workers first.
        ray_workers = self.workers
        if async_run_tensor_parallel_workers_only:
            ray_workers = self.non_driver_workers
        ray_worker_outputs = [
            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
        ]

        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_worker_output = []
        # In SPMD mode, the driver worker is the same as any other worker,
        # so we only explicitly execute on the driver worker if using a
        # non-SPMD worker class.
        if not self.use_ray_spmd_worker:
            driver_args = args if all_args is None else all_args[0]
            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

            # Start the driver worker after all the ray workers.
            if not use_dummy_driver:
                driver_worker_output = [
                    self.driver_worker.execute_method(method, *driver_args,
                                                      **driver_kwargs)
                ]
            else:
                assert self.driver_dummy_worker is not None
                driver_worker_output = [
                    ray.get(
                        self.driver_dummy_worker.execute_method.remote(
                            method, *driver_args, **driver_kwargs))
                ]

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return driver_worker_output + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the Ray DAG used in SPMD mode.

        Every TP worker of the first PP stage receives the serialized
        request. Each later stage consumes the previous stage's outputs rank
        by rank. Outputs of every stage except the last are type-hinted as
        torch tensors, using the NCCL transport when
        VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set and "auto" otherwise.

        Args:
            enable_asyncio: Forwarded to ``experimental_compile``.

        Raises:
            ValueError: If the installed Ray is older than 2.32.
        """
        # importlib.metadata replaces the deprecated pkg_resources API for
        # looking up the installed distribution version.
        from importlib.metadata import version as installed_version

        from packaging import version

        required_version = version.parse("2.32")
        current_version = version.parse(installed_version("ray"))
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        assert self.parallel_config.use_ray

        from ray.dag import InputNode, MultiOutputNode
        from ray.experimental.channel.torch_tensor_type import TorchTensorType

        logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s",
                    envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL)

        # Loop invariants: how intermediate tensors travel between PP stages,
        # and the index of the last PP stage.
        transport = "nccl" \
            if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \
            else "auto"
        last_pp_rank = len(self.pp_tp_workers) - 1

        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput   # noqa: E501
            #                         -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput   # noqa: E501
            #                         -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput   # noqa: E501
            #                         -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput   # noqa: E501

            # All workers in the first TP group will take in the
            # ExecuteModelRequest as input.
            outputs = [input_data for _ in self.pp_tp_workers[0]]
            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
                # Each PP worker takes in the output of the previous PP worker,
                # and the TP group executes in SPMD fashion.
                outputs = [
                    worker.execute_model_spmd.
                    bind(  # type: ignore[attr-defined]
                        outputs[i]) for i, worker in enumerate(tp_group)
                ]

                if pp_rank < last_pp_rank:
                    # Specify how intermediate tensors should be passed
                    # between pp stages, no need to specify for the last
                    # pp stage.
                    outputs = [
                        output.with_type_hint(
                            TorchTensorType(transport=transport))
                        for output in outputs
                    ]

            forward_dag = MultiOutputNode(outputs)

        return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)

    def __del__(self):
        self.shutdown()
467
468


469
class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
    """Asyncio variant of ``RayGPUExecutor``.

    In SPMD mode requests go through an asyncio-enabled compiled DAG.
    Otherwise the local driver worker is invoked through ``make_async``.
    With pipeline parallelism, each stage's TP driver is invoked under a
    per-stage asyncio lock.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Created lazily in _driver_execute_model_async (see the comment
        # there) instead of here.
        self.pp_locks: Optional[List[asyncio.Lock]] = None
        # ``use_ray_spmd_worker`` was already set during base-class
        # initialization. The local ``driver_worker`` only exists outside
        # SPMD mode, so only wrap its execute_method in that case.
        if not self.use_ray_spmd_worker:
            self.driver_exec_method = make_async(
                self.driver_worker.execute_method)

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Async counterpart of ``execute_model``.

        Outside SPMD mode this defers to the base-class implementation. In
        SPMD mode the request runs through the compiled DAG, which is built
        with asyncio enabled on the first call.
        """
        if not self.use_ray_spmd_worker:
            return await super().execute_model_async(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

        serialized_data = self.input_encoder.encode(execute_model_req)
        dag_future = await self.forward_dag.execute_async(serialized_data)
        outputs = await dag_future
        return self.output_decoder.decode(outputs[0])

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model on the local driver worker and, with pipeline
        parallelism, on the TP-rank-0 worker of every later PP stage.

        Returns the result of the last PP stage.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
        if not self.tp_driver_workers:
            return await self.driver_exec_method("execute_model",
                                                 execute_model_req)
        if self.pp_locks is None:
            # This locks each pipeline parallel stage so multiple virtual
            # engines can't execute on the same stage at the same time
            # We create the locks here to avoid creating them in the constructor
            # which uses a different asyncio loop.
            self.pp_locks = [
                asyncio.Lock()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]

        tasks = [
            asyncio.create_task(
                _run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
                                    "execute_model", execute_model_req))
        ]
        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
                                                start=1):
            tasks.append(
                asyncio.create_task(
                    _run_task_with_lock(driver_worker.execute_method.remote,
                                        self.pp_locks[pp_rank],
                                        "execute_model", execute_model_req)))

        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]

    async def _start_worker_execution_loop(self):
        """Start ``start_worker_execution_loop`` on every non-driver worker
        and await all of them."""
        assert not self.use_ray_spmd_worker, (
            "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.non_driver_workers
        ]
        return await asyncio.gather(*coros)