# ray_gpu_executor.py
import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import vllm.envs as envs
from vllm.executor.distributed_gpu_executor import (  # yapf: disable
    DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (_run_task_with_lock,
                        error_on_invalid_device_count_status,
                        get_distributed_init_method, get_ip, get_open_port,
                        get_vllm_instance_id, make_async)

if TYPE_CHECKING:
    # Annotation-only import; skipped at runtime.
    from ray.util.placement_group import PlacementGroup

# `ray` comes from vllm.executor.ray_utils and may be None, so the scheduling
# strategy class is only imported when Ray is available.
if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

logger = init_logger(__name__)


class RayGPUExecutor(DistributedGPUExecutor):
    """Distributed GPU executor whose workers are Ray actors.

    Two modes are selected by environment variables in ``_init_executor``:

    * default: a local ``driver_worker`` runs in this process (its GPU bundle
      is held by the ``driver_dummy_worker`` actor) and the remaining Ray
      actors are kept in ``self.workers``.
    * SPMD (``VLLM_USE_RAY_SPMD_WORKER=1``, which requires
      ``VLLM_USE_RAY_COMPILED_DAG=1``): every worker, including rank 0, is a
      Ray actor in ``self.workers`` and ``execute_model`` goes through a
      compiled Ray DAG.
    """

    uses_ray: bool = True

    def _init_executor(self) -> None:
        """Read the Ray mode flags, disable Ray usage stats, create workers."""
        self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
        # If the env var is set, it uses the Ray's compiled DAG API
        # which optimizes the control plane overhead.
        # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
        # Currently, this requires USE_RAY_SPMD_WORKER=True.
        self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
        # If the env var is set, then we do not distinguish between the
        # "driver worker" vs other workers. Also, the rank 0 worker will
        # be executed in a remote Ray worker. Currently this requires
        # USE_RAY_COMPILED_DAG=True.
        self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
        if self.use_ray_compiled_dag:
            assert self.use_ray_spmd_worker, (
                "VLLM_USE_RAY_COMPILED_DAG=1 requires "
                "VLLM_USE_RAY_SPMD_WORKER=1")
        if self.use_ray_spmd_worker:
            # TODO: Support SPMD worker for non-DAG Ray executor.
            assert self.use_ray_compiled_dag, (
                "VLLM_USE_RAY_SPMD_WORKER=1 requires "
                "VLLM_USE_RAY_COMPILED_DAG=1")

        assert self.uses_ray
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection unless the user explicitly
        # opted in with RAY_USAGE_STATS_ENABLED=1.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        """Add an nsight profiling ``runtime_env`` to the actor kwargs.

        ``ray_remote_kwargs`` is updated in place and also returned.
        """
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    def _get_worker_wrapper_args(self) -> Dict[str, Any]:
        """Return the kwargs used to construct each ``RayWorkerWrapper``.

        The speculative-decoding worker factory is selected when a
        ``speculative_config`` is present, otherwise the plain ``Worker``.
        """
        if self.speculative_config is not None:
            worker_module_name = "vllm.spec_decode.spec_decode_worker"
            worker_class_name = "create_spec_worker"
        else:
            worker_module_name = "vllm.worker.worker"
            worker_class_name = "Worker"

        return dict(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
            trust_remote_code=self.model_config.trust_remote_code,
        )

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Create one Ray actor per GPU bundle and initialize all workers.

        Args:
            placement_group: Placement group whose GPU bundles each receive
                one ``RayWorkerWrapper`` actor.
            **ray_remote_kwargs: Extra keyword arguments forwarded to
                ``ray.remote`` for every actor.

        Raises:
            ValueError: Outside SPMD mode, if no GPU bundle was placed on the
                driver's node.
        """
        if (self.parallel_config.tensor_parallel_size == 1
                and self.parallel_config.pipeline_parallel_size == 1):
            # For single GPU case, we use a ray worker with constrained memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        worker_wrapper_kwargs = self._get_worker_wrapper_args()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)

            if self.use_ray_spmd_worker:
                self.workers.append(worker)
            else:
                worker_ip = ray.get(worker.get_node_ip.remote())
                if worker_ip == driver_ip and self.driver_dummy_worker is None:
                    # If the worker is on the same node as the driver, we use it
                    # as the resource holder for the driver process.
                    self.driver_dummy_worker = worker
                    self.driver_worker = RayWorkerWrapper(
                        **worker_wrapper_kwargs)
                else:
                    # Else, added to the list of workers.
                    self.workers.append(worker)

        if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        # Get the set of GPU IDs used on each node. Entries are ordered as
        # [driver] + self.workers outside SPMD mode, and as self.workers in
        # SPMD mode (see _run_workers).
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        # the order in `worker_node_and_gpu_ids` does not necessarily match
        # the machine boundaries. We need to make sure that workers in the
        # same node are assigned consecutive ranks.
        # examples:
        # [('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [1]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [2]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [3]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [1]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [2]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [3])] # noqa

        # initialize worker ranks with -1 (unassigned)
        worker_ranks = [-1] * len(worker_node_and_gpu_ids)
        current_rank = 0
        while -1 in worker_ranks:
            # whenever we find an unassigned worker, find the node
            index = worker_ranks.index(-1)
            current_node_id = worker_node_and_gpu_ids[index][0]
            # assign ranks to all workers in the same node
            for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
                if node_id == current_node_id:
                    worker_ranks[i] = current_rank
                    current_rank += 1
        # with the above example, worker_ranks will be [0, 4, 5, 6, 7, 1, 2, 3]
        # The first entry is always assigned rank 0.

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for worker_rank, (node_id, gpu_ids) in zip(worker_ranks,
                                                   worker_node_and_gpu_ids):
            node_workers[node_id].append(worker_rank)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        VLLM_INSTANCE_ID = get_vllm_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "VLLM_INSTANCE_ID":
            VLLM_INSTANCE_ID,
            "VLLM_TRACE_FUNCTION":
            str(envs.VLLM_TRACE_FUNCTION),
        }, ) for (node_id, _) in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        if len(node_gpus) == 1:
            # in single node case, we don't need to get the IP address.
            # the loopback address is sufficient
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        error_on_invalid_device_count_status()

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id,
                         _) in zip(worker_ranks, worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[RayWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: List[RayWorkerWrapper] = []

        # Ranks of the actors in self.workers, in the same order. Outside
        # SPMD mode worker_ranks[0] belongs to the local driver worker, which
        # is not in self.workers; in SPMD mode self.workers[0] is that rank.
        first_ray_worker_index = 0 if self.use_ray_spmd_worker else 1
        ray_worker_ranks = worker_ranks[first_ray_worker_index:]

        # Enforce rank order for correct rank to return final output.
        for rank, worker in sorted(zip(ray_worker_ranks, self.workers),
                                   key=lambda pair: pair[0]):
            if rank == 0:
                # Global rank 0 is the driver, so it belongs to neither list.
                # Only reachable in SPMD mode, where it is a Ray actor.
                continue
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(worker)
            else:
                self.non_driver_workers.append(worker)

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Execute one model step.

        Outside SPMD mode this defers to the base class. In SPMD mode the
        compiled DAG is built on first use and the first worker's output
        (``self.workers[0]``) is returned.
        """
        if not self.use_ray_spmd_worker:
            return super().execute_model(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        outputs = ray.get(self.forward_dag.execute(execute_model_req))
        return outputs[0]

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
        - method: Name of the worker method, invoked via ``execute_method``.
        - async_run_tensor_parallel_workers_only: If True the method will be
          run only on ``self.non_driver_workers`` (not the driver worker and
          not the TP-group driver workers). It will also be run
          asynchronously and return a list of futures rather than blocking
          on the results. Not supported in SPMD mode.
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually. Outside SPMD mode, element 0 goes to the driver and
          the remaining elements go to ``self.workers`` in order.
        - use_dummy_driver: Outside SPMD mode, run the driver's share on the
          ``driver_dummy_worker`` Ray actor instead of the local
          ``driver_worker``.
        - max_concurrent_workers: Not supported; any truthy value raises.

        Returns:
            The driver's result (outside SPMD mode only) followed by the
            results of ``self.workers`` in order; or, when
            ``async_run_tensor_parallel_workers_only`` is True, the list of
            un-awaited Ray futures.

        Raises:
            NotImplementedError: If ``max_concurrent_workers`` is truthy.
        """
        if self.use_ray_spmd_worker:
            assert not async_run_tensor_parallel_workers_only, (
                "async_run_tensor_parallel_workers_only is not supported for "
                "spmd mode.")

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        ray_workers = (self.non_driver_workers
                       if async_run_tensor_parallel_workers_only else
                       self.workers)
        count = len(ray_workers)
        # If using SPMD worker, all workers are the same, so we should execute
        # the args on all workers. Otherwise, we skip the first worker's args
        # because those args will go to the driver worker.
        first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, first_worker_args_index, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, first_worker_args_index, None)

        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
        ]

        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_worker_output = []
        # In SPMD mode, the driver worker is the same as any other worker,
        # so we only explicitly execute on the driver worker if using a
        # non-SPMD worker class.
        if not self.use_ray_spmd_worker:
            driver_args = args if all_args is None else all_args[0]
            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

            # Start the driver worker after all the ray workers.
            if not use_dummy_driver:
                driver_worker_output = [
                    self.driver_worker.execute_method(method, *driver_args,
                                                      **driver_kwargs)
                ]
            else:
                assert self.driver_dummy_worker is not None
                driver_worker_output = [
                    ray.get(
                        self.driver_dummy_worker.execute_method.remote(
                            method, *driver_args, **driver_kwargs))
                ]

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return driver_worker_output + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only=True to complete."""
        ray.get(parallel_worker_tasks)

    def _compiled_ray_dag(self, enable_asyncio: bool):
        """Build and compile the Ray DAG used for SPMD execution.

        Args:
            enable_asyncio: Forwarded to ``experimental_compile``; the sync
                executor passes False and the async executor passes True.

        Returns:
            The compiled DAG. Its single input is the ``ExecuteModelRequest``
            handed to ``execute``/``execute_async``; its output is the list
            of each worker's ``execute_model_spmd`` result, one per entry of
            ``self.workers``.

        Raises:
            ValueError: If the installed Ray is older than 2.32.
        """
        # importlib.metadata replaces the deprecated pkg_resources API for
        # looking up an installed distribution's version.
        import importlib.metadata

        from packaging import version

        required_version = version.parse("2.32")
        current_version = version.parse(importlib.metadata.version("ray"))
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import InputNode, MultiOutputNode
        assert self.parallel_config.use_ray

        # Every worker's execute_model_spmd is bound to the one DAG input.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_spmd.bind(  # type: ignore[attr-defined]
                    input_data) for worker in self.workers
            ])
        return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)

    def __del__(self):
        """Tear down the compiled DAG (if one was built) and kill workers."""
        # getattr: __del__ may run on an instance whose construction failed
        # before _init_executor assigned forward_dag.
        forward_dag = getattr(self, "forward_dag", None)
        if forward_dag is not None:
            forward_dag.teardown()
            # NOTE(review): `ray` is re-imported locally instead of using the
            # module-level name, presumably to stay safe during interpreter
            # shutdown — confirm.
            import ray
            for worker in self.workers:
                ray.kill(worker)


class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
    """Asyncio variant of ``RayGPUExecutor``.

    Outside SPMD mode, the local driver worker's ``execute_method`` is
    wrapped with ``make_async`` so it can be awaited alongside the Ray
    actors. In SPMD mode, execution goes through an asyncio-enabled compiled
    Ray DAG. Cleanup (``__del__``) is inherited from ``RayGPUExecutor``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One lock per pipeline-parallel stage; created lazily inside the
        # running event loop by _driver_execute_model_async.
        self.pp_locks: Optional[List[asyncio.Lock]] = None
        # Same flag _init_executor reads; re-read here so it is always set.
        self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
        # A local driver_worker is only created outside SPMD mode, so only
        # wrap it in that case.
        if not self.use_ray_spmd_worker:
            self.driver_exec_method = make_async(
                self.driver_worker.execute_method)

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Execute one model step asynchronously.

        Outside SPMD mode this defers to the base class. In SPMD mode the
        asyncio-enabled compiled DAG is built on first use and the first
        worker's output is returned.
        """
        if not self.use_ray_spmd_worker:
            return await super().execute_model_async(execute_model_req)

        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

        dag_future = await self.forward_dag.execute_async(execute_model_req)
        outputs = await dag_future
        return outputs[0]

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model on the driver of every pipeline stage.

        With a single stage only the local driver runs. Otherwise the local
        driver (stage 0) and each TP-group driver worker (stages 1..N-1) run
        concurrently, each holding its stage's lock. The last stage's result
        is returned.
        """
        assert not self.use_ray_spmd_worker, (
            "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
        if not self.tp_driver_workers:
            return await self.driver_exec_method("execute_model",
                                                 execute_model_req)
        if self.pp_locks is None:
            # This locks each pipeline parallel stage so multiple virtual
            # engines can't execute on the same stage at the same time
            # We create the locks here to avoid creating them in the constructor
            # which uses a different asyncio loop.
            self.pp_locks = [
                asyncio.Lock()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]

        tasks = [
            asyncio.create_task(
                _run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
                                    "execute_model", execute_model_req))
        ]
        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
                                                start=1):
            tasks.append(
                asyncio.create_task(
                    _run_task_with_lock(driver_worker.execute_method.remote,
                                        self.pp_locks[pp_rank],
                                        "execute_model", execute_model_req)))

        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]

    async def _start_worker_execution_loop(self):
        """Start ``start_worker_execution_loop`` on every non-driver worker
        and await all of them."""
        assert not self.use_ray_spmd_worker, (
            "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.non_driver_workers
        ]
        return await asyncio.gather(*coros)