import asyncio
import os
import pickle
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

from vllm.engine.ray_utils import RayWorkerWrapper, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        get_vllm_instance_id, make_async)

# The scheduling strategy is imported only when `ray` is not None.
# NOTE(review): presumably vllm.engine.ray_utils sets `ray` to None when Ray
# is not installed -- confirm there.
if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

# Imported for static type checking only; the annotation that uses it in this
# module is written as a string.
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# Module-level logger, named after this module.
logger = init_logger(__name__)

# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# NOTE: the raw value is a string and bool("0") is True, so the value is
# compared against the usual "off" spellings instead of being passed to
# bool(). Unset, empty, "0", "false", "no" and "off" (case-insensitive)
# disable the feature; any other value enables it.
USE_RAY_COMPILED_DAG = (os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0").strip()
                        .lower() not in ("", "0", "false", "no", "off"))


class RayGPUExecutor(ExecutorBase):
    """Executor that drives GPU workers scheduled through Ray.

    One worker wrapper (``self.driver_worker``) is created directly in the
    driver process; every other worker is a Ray actor kept in
    ``self.workers``. ``self.driver_dummy_worker`` is the Ray actor that
    holds the driver's placement-group resources.
    """

    def _init_executor(self) -> None:
        """Create the Ray workers and, if enabled, compile the forward DAG."""
        assert (not self.speculative_config
                ), "Speculative decoding not yet supported for RayGPU backend."

        assert self.parallel_config.worker_use_ray
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection unless the environment already
        # opts in with RAY_USAGE_STATS_ENABLED=1.
        if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

        self.forward_dag = None
        if USE_RAY_COMPILED_DAG:
            self.forward_dag = self._compiled_ray_dag()

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        """Add the nsight profiling settings to the Ray remote kwargs.

        Args:
            ray_remote_kwargs: Keyword arguments later passed to
                ``ray.remote``. Mutated in place: a ``"runtime_env"`` dict is
                created if missing and its ``"nsight"`` entry is set.

        Returns:
            The same ``ray_remote_kwargs`` object.
        """
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Create one worker per GPU bundle and bring all of them up.

        Steps, in order: create a Ray actor for each placement-group bundle
        that has a GPU; pick the first actor on the driver's node as the
        driver's resource holder and create the in-process driver worker;
        collect node/GPU ids; push environment variables; then run
        ``init_worker``, ``init_device`` and ``load_model`` on every worker.

        Args:
            placement_group: Placement group whose ``bundle_specs`` are
                iterated to place the workers.
            **ray_remote_kwargs: Extra keyword arguments for ``ray.remote``.

        Raises:
            ValueError: If no worker ends up on the driver's node.
        """
        if self.parallel_config.tensor_parallel_size == 1:
            # For single GPU case, we use a ray worker with constrained memory.
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            # Otherwise, the ray workers are allocated with a full GPU.
            num_gpus = 1

        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            # Bundles without a GPU do not get a worker.
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name="vllm.worker.worker",
                worker_class_name="Worker",
                trust_remote_code=self.model_config.trust_remote_code,
            )

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name="vllm.worker.worker",
                    worker_class_name="Worker",
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Else, added to the list of workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        # Get the set of GPU IDs used on each node. The dummy driver actor is
        # queried here because the in-process driver worker has not been
        # initialized yet.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        # node id -> global ranks on that node / GPU ids used on that node.
        node_workers = defaultdict(list)
        node_gpus = defaultdict(list)

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        vllm_instance_id = get_vllm_instance_id()

        # Set environment variables for the driver and workers. Every worker
        # on a node receives the full, sorted GPU list of that node.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "VLLM_INSTANCE_ID":
            vllm_instance_id,
            "VLLM_TRACE_FUNCTION":
            os.getenv("VLLM_TRACE_FUNCTION", "0"),
        }, ) for (node_id, _) in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        def collect_arg_helper_func(**kwargs):
            # avoid writing `{"name": value}` manually
            return kwargs

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = []
        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            # Local rank is the position of this rank among the ranks that
            # share its node.
            local_rank = node_workers[node_id].index(rank)
            init_worker_all_kwargs.append(
                collect_arg_helper_func(
                    model_config=self.model_config,
                    parallel_config=self.parallel_config,
                    scheduler_config=self.scheduler_config,
                    device_config=self.device_config,
                    cache_config=self.cache_config,
                    load_config=self.load_config,
                    local_rank=local_rank,
                    rank=rank,
                    distributed_init_method=distributed_init_method,
                    lora_config=self.lora_config,
                    vision_language_config=self.vision_language_config,
                    is_driver_worker=rank == 0,
                ))
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        # NOTE: `_run_workers` raises NotImplementedError when
        # `max_concurrent_workers` is truthy.
        self._run_workers(
            "load_model",
            max_concurrent_workers=self.parallel_config.
            max_parallel_loading_workers,
        )

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - Tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks", )

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache in all workers.

        Args:
            num_gpu_blocks: Number of GPU blocks; stored on
                ``self.cache_config`` and forwarded to every worker.
            num_cpu_blocks: Number of CPU blocks; stored on
                ``self.cache_config`` and forwarded to every worker.
        """

        # NOTE: We log here to avoid multiple logs when number of workers is
        # greater than one. We could log in the engine, but not all executors
        # have GPUs.
        logger.info("# GPU blocks: %s, # CPU blocks: %s", num_gpu_blocks,
                    num_cpu_blocks)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(self,
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      blocks_to_swap_in: Dict[int, int],
                      blocks_to_swap_out: Dict[int, int],
                      blocks_to_copy: Dict[int, List[int]],
                      num_lookahead_slots: int = 0) -> SamplerOutput:
        """Run ``execute_model`` on all workers and return the driver's output.

        The inputs are handed to the driver worker only (as
        ``driver_kwargs``); the other workers are invoked without arguments.
        ``num_lookahead_slots`` is accepted but not forwarded by this
        executor.

        Returns:
            The first entry of the per-worker results, i.e. the driver
            worker's output.
        """
        all_outputs = self._run_workers(
            "execute_model",
            driver_kwargs={
                "seq_group_metadata_list": seq_group_metadata_list,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            },
            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

        # Only the driver worker returns the sampling results.
        output = all_outputs[0]
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Run ``add_lora`` on all workers.

        NOTE(review): the value returned is whatever `_run_workers` returns,
        which is a list with one entry per worker, not a single bool.
        """
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "add_lora",
            lora_request=lora_request,
        )

    def remove_lora(self, lora_id: int) -> bool:
        """Run ``remove_lora`` on all workers.

        NOTE(review): the value returned is whatever `_run_workers` returns,
        which is a list with one entry per worker, not a single bool.
        """
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "remove_lora",
            lora_id=lora_id,
        )

    def list_loras(self) -> Set[int]:
        """Run ``list_loras`` on all workers.

        NOTE(review): the value returned is whatever `_run_workers` returns,
        which is a list with one entry per worker, not a single set.
        """
        return self._run_workers("list_loras")

    def _run_workers(
        self,
        method: str,
        *args,
        driver_args: Optional[Tuple[Any, ...]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        - args/kwargs: All workers share the same args/kwargs
        - args/kwargs and driver_args/driver_kwargs: Driver worker has
          different args
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually

        Args:
            method: Name of the worker method to execute.
            driver_args: Positional args for the driver only. Defaults to
                ``all_args[0]`` when ``all_args`` is given, else ``args``.
            driver_kwargs: Keyword args for the driver only. Defaults to
                ``all_kwargs[0]`` when ``all_kwargs`` is given, else
                ``kwargs``.
            all_args: Per-worker positional args; entry 0 is the driver's,
                the rest are paired with ``self.workers`` in order.
            all_kwargs: Per-worker keyword args, laid out like ``all_args``.
            use_dummy_driver: Run the driver's call on the Ray actor
                ``self.driver_dummy_worker`` instead of the in-process
                ``self.driver_worker``.
            max_concurrent_workers: Not supported; any truthy value raises.
            use_ray_compiled_dag: Trigger ``self.forward_dag`` instead of
                issuing per-worker remote calls. In this mode ``method`` and
                the worker args are not sent to the Ray workers.

        Returns:
            A list: the driver's output followed by the Ray workers' outputs.

        Raises:
            NotImplementedError: If ``max_concurrent_workers`` is truthy.
        """

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        if driver_args is None:
            driver_args = args if all_args is None else all_args[0]
        if driver_kwargs is None:
            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        # Always bound, so the final concatenation cannot hit an unbound
        # local when the compiled DAG is used and there are no Ray workers.
        ray_worker_outputs: List[Any] = []

        if use_ray_compiled_dag:
            # Right now, compiled DAG can only accept a single
            # input. TODO(sang): Fix it.
            assert self.forward_dag is not None
            output_channels = self.forward_dag.execute(1)
        else:
            # Start the ray workers first.
            ray_worker_outputs = [
                worker.execute_method.remote(method, *worker_args,
                                             **worker_kwargs)
                for (worker, worker_args, worker_kwargs
                     ) in zip(self.workers, all_worker_args, all_worker_kwargs)
            ]

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))

        # Get the results of the ray workers.
        if self.workers:
            if use_ray_compiled_dag:
                try:
                    # NOTE(review): the channel payloads are unpickled here;
                    # they are expected to come from this executor's own Ray
                    # workers, not from untrusted input.
                    ray_worker_outputs = [
                        pickle.loads(chan.begin_read())
                        for chan in output_channels
                    ]
                finally:
                    # Has to call end_read in order to reuse the DAG.
                    for chan in output_channels:
                        chan.end_read()
            else:
                ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _compiled_ray_dag(self):
        """Build and compile the Ray DAG that runs the model on the workers.

        Returns:
            The result of ``experimental_compile()`` on a ``MultiOutputNode``
            that binds ``execute_model_compiled_dag_remote`` of every worker
            in ``self.workers`` to a single input node.

        Raises:
            ValueError: If the installed Ray is older than 2.9.
        """
        import pkg_resources
        required_version = "2.9"
        current_version = pkg_resources.get_distribution("ray").version
        # Compare parsed versions: as plain strings "2.10" sorts before "2.9"
        # and would be rejected by mistake.
        if (pkg_resources.parse_version(current_version) <
                pkg_resources.parse_version(required_version)):
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import InputNode, MultiOutputNode
        assert self.parallel_config.worker_use_ray

        # Right now, compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_compiled_dag_remote.bind(input_data)
                for worker in self.workers
            ])
        return forward_dag.experimental_compile()

    def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        self._check_if_any_actor_is_dead()

    def _check_if_any_actor_is_dead(self):
        """Raise RuntimeError if any actor in ``self.workers`` is DEAD.

        Only the Ray actors in ``self.workers`` are inspected; returns
        immediately when that list is empty.
        """
        if not self.workers:
            return

        dead_actors = []
        for actor in self.workers:
            actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
            if actor_state["State"] == "DEAD":
                dead_actors.append(actor)
        if dead_actors:
            raise RuntimeError("At least one Worker is dead. "
                               f"Dead Workers: {dead_actors}. ")


class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
    """Variant of RayGPUExecutor that exposes awaitable execution methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # `make_async` wraps the in-process driver worker's `execute_method`
        # so that it can be awaited together with the Ray worker calls.
        self.driver_executor = make_async(self.driver_worker.execute_method)

    async def _run_workers_async(
        self,
        method: str,
        *args,
        driver_args: Optional[Tuple[Any, ...]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            method: Name of the worker method to execute.
            driver_args: Positional args for the driver only; falls back to
                ``args`` when None.
            driver_kwargs: Keyword args for the driver only; falls back to
                ``kwargs`` when None.

        Returns:
            The gathered results: the driver's output first, then one entry
            per worker in ``self.workers``.
        """
        if driver_args is None:
            driver_args = args
        if driver_kwargs is None:
            driver_kwargs = kwargs

        # The driver call is issued first, then the ray workers.
        pending = [self.driver_executor(method, *driver_args, **driver_kwargs)]

        # Run the ray workers asynchronously.
        pending.extend(
            worker.execute_method.remote(method, *args, **kwargs)
            for worker in self.workers)

        return await asyncio.gather(*pending)

    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        """Run ``execute_model`` on all workers; return the driver's output.

        The inputs are passed to the driver worker only; the Ray workers are
        invoked without arguments.
        """
        driver_inputs = {
            "seq_group_metadata_list": seq_group_metadata_list,
            "blocks_to_swap_in": blocks_to_swap_in,
            "blocks_to_swap_out": blocks_to_swap_out,
            "blocks_to_copy": blocks_to_copy,
        }
        results = await self._run_workers_async("execute_model",
                                                driver_kwargs=driver_inputs)

        # Only the driver worker returns the sampling results.
        return results[0]