ray_gpu_executor.py 15.6 KB
Newer Older
1
2
3
import asyncio
import os
import pickle
4
from collections import defaultdict
5
from itertools import islice, repeat
6
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
7

8
import vllm.envs as envs
9
10
from vllm.executor.distributed_gpu_executor import (  # yapf: disable
    DistributedGPUExecutor, DistributedGPUExecutorAsync)
11
from vllm.executor.ray_utils import RayWorkerWrapper, ray
12
from vllm.logger import init_logger
13
from vllm.sequence import ExecuteModelRequest, SamplerOutput
14
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
15
                        get_vllm_instance_id, make_async)
16
17
18
19
20
21
22
23
24

# `ray` can be None at this point (hence the guard below); presumably
# `vllm.executor.ray_utils` exports None when Ray is unavailable -- TODO
# confirm against ray_utils.
if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

# Only needed for static type checking; used as a string annotation below.
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# Module-level logger for this file.
logger = init_logger(__name__)

25
# Snapshot of `envs.VLLM_USE_RAY_COMPILED_DAG` taken once at import time.
# When truthy, the executor builds a compiled Ray DAG in `_init_executor`.
USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
26
27


28
class RayGPUExecutor(DistributedGPUExecutor):
29

30
    def _init_executor(self) -> None:
        """Create the Ray workers and, if enabled, the compiled Ray DAG.

        Ray usage-stats collection is forced off unless the environment
        already sets ``RAY_USAGE_STATS_ENABLED`` to ``"1"``. When
        ``USE_RAY_COMPILED_DAG`` is set, the compiled DAG is stored on
        ``self.forward_dag`` and ``use_ray_compiled_dag=True`` is added to
        ``self.extra_execute_model_run_workers_kwargs``.
        """
        assert self.parallel_config.distributed_executor_backend == "ray"

        pg = self.parallel_config.placement_group

        # Leave Ray usage stats off unless the user explicitly opted in.
        if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Bring up the parallel GPU workers.
        self._init_workers_ray(pg)

        self.forward_dag = None
        if not USE_RAY_COMPILED_DAG:
            return

        self.forward_dag = self._compiled_ray_dag()
        self.extra_execute_model_run_workers_kwargs[
            "use_ray_compiled_dag"] = True
47

48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

63
64
    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Create the Ray worker actors and initialize every worker.

        One ``RayWorkerWrapper`` actor is created for each bundle of
        ``placement_group`` that has a non-zero ``"GPU"`` entry. The first
        actor whose node IP equals the driver's IP becomes
        ``self.driver_dummy_worker`` and a local ``RayWorkerWrapper`` is
        created as ``self.driver_worker``; all other actors are appended to
        ``self.workers``. Afterwards the per-worker environment variables
        are pushed, then ``init_worker``, ``init_device`` and ``load_model``
        are run on all workers, and finally ``self.tp_driver_workers`` and
        ``self.non_driver_workers`` are populated.

        Args:
            placement_group: Placement group whose ``bundle_specs`` are
                scanned for GPU bundles.
            **ray_remote_kwargs: Extra keyword arguments forwarded to
                ``ray.remote`` for every actor.

        Raises:
            ValueError: If no created actor shares the driver's IP.
        """
        parallel_cfg = self.parallel_config
        single_gpu = (parallel_cfg.tensor_parallel_size == 1
                      and parallel_cfg.pipeline_parallel_size == 1)
        # In the single-GPU case the actor requests only the configured
        # memory-utilization fraction of a GPU from Ray; otherwise every
        # actor is allocated a full GPU.
        num_gpus = (self.cache_config.gpu_memory_utilization
                    if single_gpu else 1)

        # The driver dummy worker does not actually use any resources.
        # It only holds the resource on behalf of the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # Every other actor is a real remote worker.
        self.workers: List[RayWorkerWrapper] = []

        if parallel_cfg.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers, one per GPU bundle.
        driver_ip = get_ip()
        for bundle_index, bundle_spec in enumerate(
                placement_group.bundle_specs):
            if not bundle_spec.get("GPU", 0):
                continue

            strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_index,
            )

            if self.speculative_config is None:
                worker_module_name = "vllm.worker.worker"
                worker_class_name = "Worker"
            else:
                worker_module_name = "vllm.spec_decode.spec_decode_worker"
                worker_class_name = "create_spec_worker"

            remote_wrapper_cls = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper)
            worker = remote_wrapper_cls.remote(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
                trust_remote_code=self.model_config.trust_remote_code,
            )

            worker_ip = ray.get(worker.get_node_ip.remote())
            if self.driver_dummy_worker is None and worker_ip == driver_ip:
                # This actor lives on the driver's node: keep it only as
                # the resource holder and run the driver worker locally.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name=worker_module_name,
                    worker_class_name=worker_class_name,
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Otherwise it is a regular remote worker.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        # Collect (node_id, gpu_ids) from every worker; index 0 is the
        # driver dummy worker, followed by self.workers in order.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        node_gpus = defaultdict(list)

        for worker_index, (node_id, gpu_ids) in enumerate(
                worker_node_and_gpu_ids):
            node_workers[node_id].append(worker_index)
            # `gpu_ids` can be a list of strings or integers; normalize to
            # integers. NOTE: ids can exceed 9 (e.g. 16 GPUs), so sorting
            # them as strings would be wrong.
            # see https://github.com/vllm-project/vllm/issues/5590
            node_gpus[node_id].extend(int(gpu_id) for gpu_id in gpu_ids)
        for node_id in node_gpus:
            node_gpus[node_id] = sorted(node_gpus[node_id])

        instance_id = get_vllm_instance_id()

        # Set environment variables for the driver and workers: each worker
        # sees all GPUs used on its own node.
        env_update_args = []
        for node_id, _ in worker_node_and_gpu_ids:
            env_vars = {
                "CUDA_VISIBLE_DEVICES":
                ",".join(str(gpu_id) for gpu_id in node_gpus[node_id]),
                "VLLM_INSTANCE_ID":
                instance_id,
                "VLLM_TRACE_FUNCTION":
                str(envs.VLLM_TRACE_FUNCTION),
            }
            env_update_args.append((env_vars, ))
        self._run_workers("update_environment_variables",
                          all_args=env_update_args)

        if len(node_gpus) == 1:
            # Single-node case: no real IP address is needed, the loopback
            # address is sufficient.
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. The loopback address
            # always works for communication inside the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside the worker wrappers. The
        # local rank is the worker's position among workers on its node.
        init_worker_all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            init_worker_all_kwargs.append(
                self._get_worker_kwargs(
                    local_rank=node_workers[node_id].index(rank),
                    rank=rank,
                    distributed_init_method=distributed_init_method,
                ))
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        # Workers that are rank 0 of each TP group EXCEPT global rank 0.
        # These are the workers that will broadcast to the rest.
        self.tp_driver_workers: List[RayWorkerWrapper] = []
        # Workers that are neither drivers nor the first worker of a TP
        # group. These are the workers that will be broadcasted to.
        self.non_driver_workers: List[RayWorkerWrapper] = []

        tp_size = self.parallel_config.tensor_parallel_size
        pp_size = self.parallel_config.pipeline_parallel_size
        # Global rank 0 is the local driver worker and is skipped; rank r
        # (r >= 1) corresponds to self.workers[r - 1].
        for rank in range(1, pp_size * tp_size):
            remote_worker = self.workers[rank - 1]
            if rank % tp_size == 0:
                self.tp_driver_workers.append(remote_worker)
            else:
                self.non_driver_workers.append(remote_worker)

213
    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Invoke ``execute_model`` on the local driver worker.

        Args:
            execute_model_req: The request to execute. Passing None will
                cause the driver to stop the model execution loop running
                in each of the remote workers.

        Returns:
            Whatever the driver worker's ``execute_model`` returns.
        """
        driver = self.driver_worker
        return driver.execute_method("execute_model", execute_model_req)
223
224
225
226
227

    def _run_workers(
        self,
        method: str,
        *args,
228
        async_run_tensor_parallel_workers_only: bool = False,
229
        all_args: Optional[List[Tuple[Any, ...]]] = None,
230
231
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
232
233
234
235
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
236
237
238
        """Runs the given method on all workers. Can be used in the following
        ways:

239
240
241
242
243
        Args:
        - async_run_tensor_parallel_workers_only: If True the method will be
          run only in the remote TP workers, not the driver worker.
          It will also be run asynchronously and return a list of futures
          rather than blocking on the results.
244
245
246
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually
247
        """
248
249
250
251
252

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

253
254
255
        count = len(self.workers) if not \
            async_run_tensor_parallel_workers_only \
            else len(self.non_driver_workers)
256
257
258
259
260
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

261
262
263
        if use_ray_compiled_dag:
            # Right now, compiled DAG can only accept a single
            # input. TODO(sang): Fix it.
264
            assert self.forward_dag is not None
265
            output_channels = self.forward_dag.execute(1)
266
            ray_worker_outputs = []
267
268
        else:
            # Start the ray workers first.
269
270
271
            ray_workers = self.workers
            if async_run_tensor_parallel_workers_only:
                ray_workers = self.non_driver_workers
272
            ray_worker_outputs = [
273
274
275
                worker.execute_method.remote(method, *worker_args,
                                             **worker_kwargs)
                for (worker, worker_args, worker_kwargs
276
                     ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
277
278
            ]

279
        if async_run_tensor_parallel_workers_only:
280
281
282
283
284
285
            # Just return futures
            return ray_worker_outputs

        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

286
        # Start the driver worker after all the ray workers.
287
288
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
289
                method, *driver_args, **driver_kwargs)
290
        else:
291
            assert self.driver_dummy_worker is not None
292
293
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
294
                    method, *driver_args, **driver_kwargs))
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
        # Get the results of the ray workers.
        if self.workers:
            if use_ray_compiled_dag:
                try:
                    ray_worker_outputs = [
                        pickle.loads(chan.begin_read())
                        for chan in output_channels
                    ]
                finally:
                    # Has to call end_read in order to reuse the DAG.
                    for chan in output_channels:
                        chan.end_read()
            else:
                ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

312
313
314
315
316
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for the given Ray futures to complete.

        Args:
            parallel_worker_tasks: Futures returned from ``_run_workers()``
                when it is called with
                ``async_run_tensor_parallel_workers_only=True``.
        """
        ray.get(parallel_worker_tasks)

317
318
319
320
321
322
323
324
    def _compiled_ray_dag(self):
        """Build and compile a Ray DAG spanning all remote workers.

        The DAG binds ``execute_model_compiled_dag_remote`` on every worker
        in ``self.workers`` to a single shared input node and is compiled
        with ``experimental_compile()``.

        Returns:
            The compiled DAG object.

        Raises:
            ValueError: If the installed Ray version is older than 2.9.
        """
        import pkg_resources
        required_version = "2.9"
        current_version = pkg_resources.get_distribution("ray").version
        # Compare parsed versions rather than raw strings: as strings,
        # "2.10" < "2.9" is True, which would wrongly reject Ray >= 2.10.
        if (pkg_resources.parse_version(current_version) <
                pkg_resources.parse_version(required_version)):
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import InputNode, MultiOutputNode
        assert self.parallel_config.distributed_executor_backend == "ray"

        # Right now, compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_compiled_dag_remote.
                bind(  # type: ignore[attr-defined]
                    input_data) for worker in self.workers
            ])
        return forward_dag.experimental_compile()


339
class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
    """Coroutine-based counterpart of ``RayGPUExecutor``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Awaitable wrapper around the local driver worker's execute_method
        # (it is awaited in `_driver_execute_model_async`).
        self.driver_exec_method = make_async(self.driver_worker.execute_method)

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run ``execute_model`` on the driver of every pipeline stage.

        Stage 0 is the local driver worker; stage ``i`` (``i >= 1``) is
        ``self.tp_driver_workers[i - 1]``. Each call is made while holding
        ``self.pp_locks[i]`` for its stage.

        Returns:
            The result of the last stage, since only the last PP stage has
            the final results.
        """

        async def _call_under_lock(call, lock, *call_args, **call_kwargs):
            async with lock:
                return await call(*call_args, **call_kwargs)

        # Stage 0: the local driver worker.
        stage_tasks = [
            asyncio.create_task(
                _call_under_lock(self.driver_exec_method, self.pp_locks[0],
                                 "execute_model", execute_model_req))
        ]
        # Stages 1..N: the remote TP-group drivers.
        for stage, remote_driver in enumerate(self.tp_driver_workers,
                                              start=1):
            remote_call = remote_driver.execute_method.remote
            stage_tasks.append(
                asyncio.create_task(
                    _call_under_lock(remote_call, self.pp_locks[stage],
                                     "execute_model", execute_model_req)))

        stage_results = await asyncio.gather(*stage_tasks)

        # Only the last PP stage has the final results.
        return stage_results[-1]

    async def _start_worker_execution_loop(self):
        """Start ``start_worker_execution_loop`` on every non-driver worker.

        Returns:
            The gathered results of all the remote calls.
        """
        pending = []
        for remote_worker in self.non_driver_workers:
            pending.append(
                remote_worker.execute_method.remote(
                    "start_worker_execution_loop"))
        return await asyncio.gather(*pending)