# ray_gpu_executor.py
import asyncio
import os
import pickle
4
from collections import defaultdict
5
from itertools import islice, repeat
6
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
7
8

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
9
from vllm.executor.ray_utils import RayWorkerWrapper, ray
10
11
12
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
13
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
14
                        get_vllm_instance_id, make_async)
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)

# If the env var is set to a true-like value, vLLM uses Ray's compiled DAG
# API, which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# NOTE: a plain ``bool(os.getenv(...))`` treats every non-empty string --
# including "0" and "false" -- as enabled, so explicit false-like values
# (and an unset/empty variable) are mapped to False here.
_use_ray_compiled_dag_env = os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0")
USE_RAY_COMPILED_DAG = _use_ray_compiled_dag_env.strip().lower() not in (
    "", "0", "false", "no", "off")


class RayGPUExecutor(ExecutorBase):
    """Executor that distributes model execution over Ray GPU workers.

    One worker (``driver_worker``) runs inside the driver process. Every other
    GPU bundle of the placement group gets a Ray actor, stored in
    ``workers``. An additional Ray actor on the driver's node
    (``driver_dummy_worker``) holds the placement-group resources used by the
    driver process.
    """

    def _init_executor(self) -> None:
        """Create the Ray workers and, if enabled, the compiled Ray DAG."""
        assert (not self.speculative_config
                ), "Speculative decoding not yet supported for RayGPU backend."

        assert self.parallel_config.worker_use_ray
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection unless the user explicitly
        # opted in with RAY_USAGE_STATS_ENABLED=1.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        self.forward_dag = None
        if USE_RAY_COMPILED_DAG:
            self.forward_dag = self._compiled_ray_dag()

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        """Add an Nsight profiling ``runtime_env`` to the Ray actor options.

        ``ray_remote_kwargs`` is modified in place (a ``runtime_env`` entry is
        created if missing) and is also returned.
        """
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Create one Ray actor per GPU bundle and initialize all workers.

        Args:
            placement_group: Ray placement group whose GPU bundles host the
                workers.
            **ray_remote_kwargs: Extra options forwarded to ``ray.remote``.

        Raises:
            ValueError: if no GPU bundle is placed on the driver's node.
        """
        if self.parallel_config.tensor_parallel_size == 1:
            # For single GPU case, we use a ray worker with constrained memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            # Skip bundles that do not reserve a GPU.
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name="vllm.worker.worker",
                worker_class_name="Worker",
                trust_remote_code=self.model_config.trust_remote_code,
            )

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name="vllm.worker.worker",
                    worker_class_name="Worker",
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Else, added to the list of workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        # Get the set of GPU IDs used on each node. Entry 0 belongs to the
        # driver and is queried through the dummy actor that holds the
        # driver's Ray resources; the rest follow the order of self.workers.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        # Group global ranks and GPU ids by node.
        node_workers = defaultdict(list)
        node_gpus = defaultdict(list)

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        VLLM_INSTANCE_ID = get_vllm_instance_id()

        # Set environment variables for the driver and workers. Every process
        # on a node gets the sorted list of all GPUs assigned to that node.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "VLLM_INSTANCE_ID":
            VLLM_INSTANCE_ID,
            "VLLM_TRACE_FUNCTION":
            os.getenv("VLLM_TRACE_FUNCTION", "0"),
        }, ) for (node_id, _) in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper. The local rank
        # is the worker's position among the workers on the same node.
        init_worker_all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            local_rank = node_workers[node_id].index(rank)
            init_worker_all_kwargs.append(
                dict(
                    model_config=self.model_config,
                    parallel_config=self.parallel_config,
                    scheduler_config=self.scheduler_config,
                    device_config=self.device_config,
                    cache_config=self.cache_config,
                    load_config=self.load_config,
                    local_rank=local_rank,
                    rank=rank,
                    distributed_init_method=distributed_init_method,
                    lora_config=self.lora_config,
                    vision_language_config=self.vision_language_config,
                    is_driver_worker=rank == 0,
                ))
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers(
            "load_model",
            max_concurrent_workers=self.parallel_config.
            max_parallel_loading_workers,
        )

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - Tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks", )

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache in all workers.

        Also records the block counts on ``self.cache_config``.
        """

        # NOTE: We log here to avoid multiple logs when number of workers is
        # greater than one. We could log in the engine, but not all executors
        # have GPUs.
        logger.info("# GPU blocks: %s, # CPU blocks: %s", num_gpu_blocks,
                    num_cpu_blocks)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(self,
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      blocks_to_swap_in: Dict[int, int],
                      blocks_to_swap_out: Dict[int, int],
                      blocks_to_copy: Dict[int, List[int]],
                      num_lookahead_slots: int = 0) -> SamplerOutput:
        """Run one model step on all workers and return the driver's output.

        Only the driver receives the scheduling inputs; the Ray workers are
        called with no arguments. ``num_lookahead_slots`` is accepted for
        interface compatibility but is not forwarded to the workers
        (speculative decoding is rejected in ``_init_executor``).
        """
        all_outputs = self._run_workers(
            "execute_model",
            driver_kwargs={
                "seq_group_metadata_list": seq_group_metadata_list,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            },
            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

        # Only the driver worker returns the sampling results.
        output = all_outputs[0]
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add a LoRA adapter on every worker (driver included).

        Returns:
            True only if every worker's ``add_lora`` returned a truthy value.
            (Returning the raw list would always be truthy.)
        """
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return all(
            self._run_workers(
                "add_lora",
                lora_request=lora_request,
            ))

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter from every worker (driver included).

        Returns:
            True only if every worker's ``remove_lora`` returned a truthy
            value.
        """
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(
            self._run_workers(
                "remove_lora",
                lora_id=lora_id,
            ))

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids loaded on the driver worker.

        ``add_lora``/``remove_lora`` are broadcast to all workers, so the
        driver's set is reported, matching the declared ``Set[int]`` type.
        """
        return self._run_workers("list_loras")[0]

    def _run_workers(
        self,
        method: str,
        *args,
        driver_args: Optional[Tuple[Any, ...]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        - args/kwargs: All workers share the same args/kwargs
        - args/kwargs and driver_args/driver_kwargs: Driver worker has
          different args
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually

        Ray workers are started first and the driver runs afterwards. With
        ``use_dummy_driver`` the driver call goes to ``driver_dummy_worker``
        instead of the local ``driver_worker``. With ``use_ray_compiled_dag``
        the Ray workers run through ``self.forward_dag`` with a dummy input,
        so their per-worker args/kwargs are not sent.

        Returns:
            A list: the driver's result followed by one result per entry of
            ``self.workers``.

        Raises:
            NotImplementedError: if ``max_concurrent_workers`` is truthy.
        """

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        if driver_args is None:
            driver_args = args if all_args is None else all_args[0]
        if driver_kwargs is None:
            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        # Bound up front so the final return is valid even when no Ray worker
        # results are collected (e.g. compiled DAG with no Ray workers).
        ray_worker_outputs: List[Any] = []
        if use_ray_compiled_dag:
            # Right now, compiled DAG can only accept a single
            # input. TODO(sang): Fix it.
            assert self.forward_dag is not None
            output_channels = self.forward_dag.execute(1)
        else:
            # Start the ray workers first.
            ray_worker_outputs = [
                worker.execute_method.remote(method, *worker_args,
                                             **worker_kwargs)
                for (worker, worker_args, worker_kwargs
                     ) in zip(self.workers, all_worker_args, all_worker_kwargs)
            ]

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            assert self.driver_dummy_worker is not None
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))

        # Get the results of the ray workers.
        if self.workers:
            if use_ray_compiled_dag:
                try:
                    # Channel payloads are produced by our own Ray workers.
                    ray_worker_outputs = [
                        pickle.loads(chan.begin_read())
                        for chan in output_channels
                    ]
                finally:
                    # Has to call end_read in order to reuse the DAG.
                    for chan in output_channels:
                        chan.end_read()
            else:
                ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _compiled_ray_dag(self):
        """Build and compile a Ray DAG that calls every Ray worker.

        Raises:
            ValueError: if the installed Ray version is older than 2.9.
        """
        import pkg_resources
        required_version = "2.9"
        current_version = pkg_resources.get_distribution("ray").version
        # Compare parsed versions: a plain string comparison would treat
        # e.g. "2.10.0" as older than "2.9".
        if (pkg_resources.parse_version(current_version) <
                pkg_resources.parse_version(required_version)):
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import InputNode, MultiOutputNode
        assert self.parallel_config.worker_use_ray

        # Right now, compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_compiled_dag_remote.
                bind(  # type: ignore[attr-defined]
                    input_data) for worker in self.workers
            ])
        return forward_dag.experimental_compile()

    def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        self._check_if_any_actor_is_dead()

    def _check_if_any_actor_is_dead(self):
        """Raise RuntimeError if Ray reports any entry of ``workers`` DEAD."""
        if not self.workers:
            return

        dead_actors = []
        for actor in self.workers:
            actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
            if actor_state["State"] == "DEAD":
                dead_actors.append(actor)
        if dead_actors:
            raise RuntimeError("At least one Worker is dead. "
                               f"Dead Workers: {dead_actors}. ")


class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
    """Asyncio-friendly variant of ``RayGPUExecutor``.

    The in-process driver's ``execute_method`` is wrapped with ``make_async``.
    The Ray actors' results are awaited alongside it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Awaitable wrapper around the local driver worker's dispatcher.
        self.driver_executor = make_async(self.driver_worker.execute_method)

    async def _run_workers_async(
        self,
        method: str,
        *args,
        driver_args: Optional[Tuple[Any, ...]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        The driver gets ``driver_args``/``driver_kwargs`` when provided, and
        otherwise the shared ``args``/``kwargs``. The Ray workers always get
        the shared ``args``/``kwargs``.

        Returns:
            A list: the driver's result first, then one result per Ray worker.
        """
        lead_args = args if driver_args is None else driver_args
        lead_kwargs = kwargs if driver_kwargs is None else driver_kwargs

        # The driver call is issued before the Ray actor calls.
        pending = [self.driver_executor(method, *lead_args, **lead_kwargs)]
        # Run the ray workers asynchronously.
        pending.extend(
            actor.execute_method.remote(method, *args, **kwargs)
            for actor in self.workers)

        return await asyncio.gather(*pending)

    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        """Run one model step on all workers and return the driver's output.

        Only the driver receives the scheduling inputs; the Ray workers are
        invoked with no arguments.
        """
        step_inputs = {
            "seq_group_metadata_list": seq_group_metadata_list,
            "blocks_to_swap_in": blocks_to_swap_in,
            "blocks_to_swap_out": blocks_to_swap_out,
            "blocks_to_copy": blocks_to_copy,
        }
        results = await self._run_workers_async("execute_model",
                                                driver_kwargs=step_inputs)
        # Only the driver worker returns the sampling results.
        return results[0]