ray_utils.py 23 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
import time
6
7
8
9
10
11
12
13
import queue
from collections import defaultdict, deque
from collections.abc import Callable, Sequence
from concurrent.futures import Future, InvalidStateError
from typing import TYPE_CHECKING, Union, Any
from threading import Thread
from enum import Enum, auto
from contextlib import suppress
14

15
import vllm.platforms
16
from vllm.config import ParallelConfig
17
from vllm.distributed import get_pp_group
18
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
19
from vllm.logger import init_logger
20
from vllm.platforms import current_platform
21
from vllm.sequence import IntermediateTensors
22
from vllm.utils.network_utils import get_ip
23
from vllm.v1.outputs import AsyncModelRunnerOutput
24
from vllm.v1.worker.worker_base import WorkerWrapperBase
25

26
if TYPE_CHECKING:
27
    from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
28
29
    from vllm.v1.outputs import ModelRunnerOutput

30
logger = init_logger(__name__)
31
PG_WAIT_TIMEOUT = 1800
32
33
34

try:
    import ray
35
    from ray.util import placement_group_table
36
    from ray.util.queue import Queue as RayQueue
37
    from ray.util.placement_group import PlacementGroup
38

39
40
41
42
43
    try:
        from ray._private.state import available_resources_per_node
    except ImportError:
        # Ray 2.9.x doesn't expose `available_resources_per_node`
        from ray._private.state import state as _state
44

45
        available_resources_per_node = _state._available_resources_per_node
46

47
    class RayWorkerWrapper(WorkerWrapperBase):
        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES.

        With async scheduling enabled, model outputs are additionally
        forwarded to ``response_mq`` as ``(ResponseStatus, payload)`` tuples
        by a background daemon thread.
        """

        def __init__(self, use_async_scheduling: bool, response_mq: RayQueue, *args, **kwargs) -> None: # type: ignore
            """Create the wrapper.

            Args:
                use_async_scheduling: If True, start the daemon thread that
                    drains ``async_output_queue`` into ``response_mq``.
                response_mq: Queue receiving ``(ResponseStatus, payload)``
                    tuples; ``None`` disables forwarding (see enqueue_output).
                *args, **kwargs: Forwarded to ``WorkerWrapperBase``.
            """
            super().__init__(*args, **kwargs)
            # The compiled DAG runs its main execution in a different thread,
            # which must call set_device itself. This flag records whether
            # set_device has already been called on that thread.
            self.compiled_dag_cuda_device_set = False

            # async scheduling
            self.use_async_scheduling = use_async_scheduling
            self.worker_response_mq = response_mq
            if self.use_async_scheduling:
                self.async_output_queue: queue.Queue = queue.Queue()
                self.async_output_copy_thread = Thread(
                    target=self.async_output_busy_loop,
                    daemon=True,
                    name="WorkerAsyncOutputCopy",
                )
                self.async_output_copy_thread.start()

        class ResponseStatus(Enum):
            """Status tag attached to every entry put on the response queue."""

            SUCCESS = auto()
            FAILURE = auto()

        def get_node_ip(self) -> str:
            """Return the IP address of the node this actor runs on."""
            return get_ip()

        def get_node_and_gpu_ids(self) -> tuple[str, list[int]]:
            """Return this actor's Ray node id and its assigned accelerator ids.

            Raises:
                RuntimeError: If the current platform has no Ray device key.
            """
            node_id = ray.get_runtime_context().get_node_id()
            device_key = vllm.platforms.current_platform.ray_device_key
            if not device_key:
                # Format eagerly: exception constructors do not apply
                # logger-style %-formatting to extra positional arguments.
                raise RuntimeError(
                    f"current platform {vllm.platforms.current_platform.device_name}"
                    " does not support ray."
                )
            gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key]
            return node_id, gpu_ids

        def setup_device_if_necessary(self):
            """Select the worker's device on the calling thread, once."""
            # TODO(swang): This is needed right now because Ray CG executes
            # on a background thread, so we need to reset torch's current
            # device.
            # We can remove this API after it is fixed in compiled graph.
            assert self.worker is not None, "Worker is not initialized"
            if not self.compiled_dag_cuda_device_set:
                if current_platform.is_tpu():
                    # Not needed
                    pass
                else:
                    assert self.worker.device is not None
                    current_platform.set_device(self.worker.device)

                self.compiled_dag_cuda_device_set = True

        def _should_report_output(self) -> bool:
            """Whether this rank forwards its outputs to the response queue.

            An ``output_rank`` of -1 means every rank reports. ``output_rank``
            is not assigned anywhere in this class, so a missing attribute is
            treated as -1 rather than raising AttributeError on the hot path.
            NOTE(review): confirm -1 is the intended default.
            """
            output_rank = getattr(self, "output_rank", -1)
            return output_rank == -1 or self.rpc_rank == output_rank

        def execute_model_ray(
            self,
            execute_model_input: tuple["SchedulerOutput", "GrammarOutput"]
            | tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
        ) -> Union[
            "ModelRunnerOutput",
            "AsyncModelRunnerOutput",
            tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
        ]:
            """Run one model step; this is the entry point used by Ray
            Compiled Graph.

            Args:
                execute_model_input: ``(scheduler_output, grammar_output)``,
                    optionally followed by ``IntermediateTensors`` (presumably
                    from the previous pipeline stage).

            Returns:
                ``(scheduler_output, grammar_output, tensors)`` when the model
                runner produced ``IntermediateTensors``;
                ``(scheduler_output, grammar_output, None)`` on ranks that are
                not the last pipeline-parallel rank; otherwise the model
                runner output. ``AsyncModelRunnerOutput`` values are unwrapped
                via ``get_output()`` or ``get_output_async()``.
            """
            # Compiled Graph executes on a background thread, so the device
            # has to be selected on that thread before running the model.
            self.setup_device_if_necessary()
            assert self.worker is not None, "Worker is not initialized"
            if len(execute_model_input) == 3:
                scheduler_output, grammar_output, intermediate_tensors = (
                    execute_model_input
                )
            else:
                scheduler_output, grammar_output = execute_model_input
                intermediate_tensors = None
            assert self.worker.model_runner is not None
            output = self.worker.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            if self._is_intermediate_tensors(output):
                return scheduler_output, grammar_output, output

            if isinstance(output, AsyncModelRunnerOutput):
                # NOTE(review): get_output_async() is not defined in this
                # file; presumably it yields a DAG-serializable value without
                # blocking on the device copy -- confirm.
                if not self.use_async_scheduling:
                    output = output.get_output()
                else:
                    output = output.get_output_async()
            if not get_pp_group().is_last_rank:
                # Case where there are no scheduled requests
                # but may still be finished requests.
                assert not output or not output.req_ids
                output = scheduler_output, grammar_output, None
            elif output is None:
                # execute_model produced no output: obtain it via sampling.
                output = self.worker.model_runner.sample_tokens(grammar_output)

                if self.use_async_scheduling:
                    # Hand the raw output to the copy thread before unwrapping
                    # it for the DAG return value.
                    if self._should_report_output():
                        self.handle_output(output)

                    if isinstance(output, AsyncModelRunnerOutput):
                        output = output.get_output_async()
                else:
                    # Ensure outputs crossing Ray compiled DAG are serializable.
                    # AsyncModelRunnerOutput holds CUDA events and cannot be
                    # pickled.
                    if isinstance(output, AsyncModelRunnerOutput):
                        output = output.get_output()
            elif self.use_async_scheduling and self._should_report_output():
                self.handle_output(output)

            return output

        def override_env_vars(self, vars: dict[str, str]):
            """Merge ``vars`` into this process's environment."""
            os.environ.update(vars)

        def _is_intermediate_tensors(self, output) -> bool:
            """Return True if ``output`` is an ``IntermediateTensors``."""
            return isinstance(output, IntermediateTensors)

        def enqueue_output(self, output: Any):
            """Prepares output from the worker and enqueues it to the
            worker_response_mq. If the output is an Exception, it is
            converted to a FAILURE response.

            An ``AsyncModelRunnerOutput`` is first resolved with
            ``get_output()``. If that raises, the exception is reported as
            FAILURE instead of propagating. An escaping exception would end
            ``async_output_busy_loop`` and silently stop all later
            responses. Nothing is sent when ``worker_response_mq`` is None.
            """
            if isinstance(output, AsyncModelRunnerOutput):
                try:
                    output = output.get_output()
                except Exception as e:
                    logger.exception("Failed to resolve async model runner output")
                    output = e

            if isinstance(output, Exception):
                result = (RayWorkerWrapper.ResponseStatus.FAILURE, str(output))
            else:
                result = (RayWorkerWrapper.ResponseStatus.SUCCESS, output)

            if (response_mq := self.worker_response_mq) is not None:
                response_mq.put(result)

        def handle_output(self, output: Any):
            """Handles output from the worker. If async scheduling is enabled,
            it is passed to the async_output_busy_loop thread. Otherwise, it is
            enqueued directly to the worker_response_mq.
            """
            if self.use_async_scheduling:
                self.async_output_queue.put(output)
            else:
                self.enqueue_output(output)

        def async_output_busy_loop(self):
            """Entrypoint for the thread which handles outputs asynchronously.

            Blocks on ``async_output_queue`` forever and forwards each item
            through ``enqueue_output``.
            """
            while True:
                output = self.async_output_queue.get()
                self.enqueue_output(output)

203

204
205
    ray_import_err = None

206
except ImportError as e:
207
    ray = None  # type: ignore
208
209
210
    # only capture string to avoid variable references in the traceback that can
    # prevent garbage collection in some cases
    ray_import_err = str(e)
211
    RayWorkerWrapper = None  # type: ignore
212
213


214
215
216
217
218
219
220
221
222
class FutureWrapper(Future):
    """A wrapper around Ray output reference(s) that meets the interface of
    .execute_model(): the top level (core busy loop) expects a blocking
    .result() API.

    If an aggregator is provided, the outputs of all referenced workers are
    aggregated (with ``output_rank=0``) on every result() call. Otherwise the
    value of ``ray.get(ref_or_refs)`` is returned unchanged: one output for a
    single ref, or a list of outputs for a list of refs.

    Note: the underlying ``Future`` state is never set, so ``done()`` and
    done-callbacks do not track the Ray task(s); only result() is meaningful.
    """

    def __init__(self, ref_or_refs, aggregator: KVOutputAggregator | None = None):
        """
        Args:
            ref_or_refs: A Ray ObjectRef or a list of ObjectRefs.
            aggregator: Optional aggregator combining per-worker outputs.
        """
        super().__init__()
        self.ref_or_refs = ref_or_refs
        self.aggregator = aggregator

    def result(self, timeout=None):
        """Block on the Ray reference(s) and return the (aggregated) output.

        Args:
            timeout: Seconds to wait, forwarded to ``ray.get``; None waits
                indefinitely.
        """
        outputs = ray.get(self.ref_or_refs, timeout=timeout)
        if self.aggregator is None:
            return outputs

        return self.aggregator.aggregate(outputs, output_rank=0)

235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
class NonBlockFutureWrapper(Future):
    """A Future resolved lazily by draining a shared queue of pending work.

    Each ``futures_queue`` entry is a ``(future, get_response)`` pair.
    ``result()`` pops entries from the right end of the deque and resolves
    each one until this future is done. Entries queued ahead of this one are
    therefore resolved first, assuming producers add entries with
    ``appendleft``. NOTE(review): confirm against the code that fills the
    queue.
    """

    def __init__(
        self,
        futures_queue: deque[tuple["NonBlockFutureWrapper", Callable]],
        aggregate: Callable = lambda x: x,
    ):
        """
        Args:
            futures_queue: Deque shared by all pending futures; its entries
                are consumed by ``result()``.
            aggregate: Applied to the raw value from ``get_response`` before
                it becomes this future's result. Defaults to identity.
        """
        self.futures_queue = futures_queue
        self.aggregate = aggregate
        super().__init__()

    def result(self, timeout=None):
        """Resolve queued futures until this one is done; return its result.

        Raises:
            RuntimeError: If ``timeout`` is given (not supported).
            IndexError: If the queue empties before this future is resolved.
            Exception: Whatever ``get_response``/``aggregate`` raised for
                this future.
        """
        if timeout is not None:
            raise RuntimeError("timeout not implemented")
        # Drain any futures ahead of us in the queue.
        while not self.done():
            future, get_response = self.futures_queue.pop()
            future.wait_for_response(get_response)
        return super().result()

    def wait_for_response(self, get_response: Callable):
        """Resolve this future from ``get_response()``.

        Exceptions raised while fetching or aggregating are stored on the
        future instead of propagating. ``InvalidStateError`` is suppressed,
        so resolving an already-resolved future is a no-op.
        """
        try:
            response = self.aggregate(get_response())
            with suppress(InvalidStateError):
                self.set_result(response)
        except Exception as e:
            with suppress(InvalidStateError):
                self.set_exception(e)

263

264
265
266
267
268
269
270
271
def ray_is_available() -> bool:
    """Report whether the optional Ray dependency imported successfully."""
    if ray is None:
        return False
    return True


def assert_ray_available():
    """Raise an exception if Ray is not available.

    Raises:
        ValueError: Carrying the captured import error and an install hint.
    """
    if ray is None:
        # Trailing space keeps the two sentences from running together.
        raise ValueError(
            f"Failed to import Ray: {ray_import_err}. "
            "Please install Ray with `pip install ray`."
        )
276
277


278
279
280
def _verify_bundles(
    placement_group: "PlacementGroup", parallel_config: ParallelConfig, device_str: str
):
    """Verify a given placement group has bundles located in the right place.

    There are 2 rules.
    - Warn if all tensor parallel workers cannot fit in a single node.
    - Fail if driver node is not included in a placement group.

    Args:
        placement_group: The placement group to inspect.
        parallel_config: Supplies ``tensor_parallel_size`` for the per-node
            check.
        device_str: Ray resource key of the accelerator (used in messages).

    Raises:
        RuntimeError: If the driver's node hosts none of the group's bundles.
    """
    assert ray.is_initialized(), (
        "Ray is not initialized although distributed-executor-backend is ray."
    )
    pg_data = placement_group_table(placement_group)
    # bundle_idx -> node_id
    bundle_to_node_ids = pg_data["bundles_to_node_id"]
    # bundle_idx -> bundle (e.g., {"GPU": 1})
    bundles = pg_data["bundles"]
    # node_id -> List of bundle (e.g., {"GPU": 1})
    node_id_to_bundle: dict[str, list[dict[str, float]]] = defaultdict(list)

    for bundle_idx, node_id in bundle_to_node_ids.items():
        node_id_to_bundle[node_id].append(bundles[bundle_idx])
    driver_node_id = ray.get_runtime_context().get_node_id()

    if driver_node_id not in node_id_to_bundle:
        raise RuntimeError(
            f"driver node id {driver_node_id} is not included in a placement "
            f"group {placement_group.id}. Node id -> bundles "
            f"{node_id_to_bundle}. "
            "You don't have enough GPUs available in a current node. Check "
            "`ray status` and `ray list nodes` to see if you have available "
            f"GPUs in a node `{driver_node_id}` before starting an vLLM engine."
        )

    # Use a distinct name so the `bundles` mapping above is not shadowed.
    for node_id, node_bundles in node_id_to_bundle.items():
        if len(node_bundles) < parallel_config.tensor_parallel_size:
            logger.warning(
                "tensor_parallel_size=%d "
                "is bigger than a reserved number of %ss (%d "
                "%ss) in a node %s. Tensor parallel workers can be "
                "spread out to 2+ nodes which can degrade the performance "
                "unless you have fast interconnect across nodes, like "
                "Infiniband. To resolve this issue, make sure you have more "
                "than %d GPUs available at each node.",
                parallel_config.tensor_parallel_size,
                device_str,
                len(node_bundles),
                device_str,
                node_id,
                parallel_config.tensor_parallel_size,
            )
329
330
331
332
333
334
335
336
337
338


def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
    """Wait until a placement group is ready.

    It prints the informative log messages if the placement group is
    not created within time. The total wait is bounded by PG_WAIT_TIMEOUT.

    Raises:
        ValueError: If the placement group is still not ready after
            PG_WAIT_TIMEOUT seconds.
    """
    # Wait until PG is ready - this will block until all
    # requested resources are available, and will time out
    # if they cannot be provisioned.
    placement_group_specs = current_placement_group.bundle_specs

    s = time.time()
    pg_ready_ref = current_placement_group.ready()
    wait_interval = 10
    while (elapsed := time.time() - s) < PG_WAIT_TIMEOUT:
        # Clamp each wait to the remaining budget. Otherwise the doubling
        # interval lets the final wait overshoot PG_WAIT_TIMEOUT by minutes.
        ready, _ = ray.wait(
            [pg_ready_ref], timeout=min(wait_interval, PG_WAIT_TIMEOUT - elapsed)
        )
        if len(ready) > 0:
            break

        # Exponential backoff for warning print.
        wait_interval *= 2
        logger.info(
            "Waiting for creating a placement group of specs for "
            "%d seconds. specs=%s. Check `ray status` and "
            "`ray list nodes` to see if you have enough resources,"
            " and make sure the IP addresses used by ray cluster"
            " are the same as VLLM_HOST_IP environment variable"
            " specified in each node if you are running on a multi-node.",
            int(time.time() - s),
            placement_group_specs,
        )

    try:
        ray.get(pg_ready_ref, timeout=0)
    except ray.exceptions.GetTimeoutError:
        # Provide more helpful error message when GPU count is exceeded
        total_gpu_required = sum(spec.get("GPU", 0) for spec in placement_group_specs)
        # If more than one GPU is required for the placement group, provide a
        # more specific error message.
        # We use >1 here because multi-GPU (tensor parallel) jobs are more
        # likely to fail due to insufficient cluster resources, and users may
        # need to adjust tensor_parallel_size to fit available GPUs.
        if total_gpu_required > 1:
            raise ValueError(
                f"Cannot provide a placement group requiring "
                f"{total_gpu_required} GPUs "
                f"(placement_group_specs={placement_group_specs}) within "
                f"{PG_WAIT_TIMEOUT} seconds.\n"
                f"Tensor parallel size may exceed available GPUs in your "
                f"cluster. Check resources with `ray status` and "
                f"`ray list nodes`.\n"
                f"If running on K8s with limited GPUs, consider reducing "
                f"--tensor-parallel-size to match available GPU resources."
            ) from None
        else:
            raise ValueError(
                "Cannot provide a placement group of "
                f"{placement_group_specs=} within "
                f"{PG_WAIT_TIMEOUT} seconds. See "
                "`ray status` and `ray list nodes` to make sure the cluster "
                "has enough resources."
            ) from None
394
395
396
397
398
399
400
401
402
403
404
405
406
407


def _wait_until_pg_removed(current_placement_group: "PlacementGroup"):
    """Remove a placement group, then poll (for at most PG_WAIT_TIMEOUT
    seconds) until ``ray.util.get_current_placement_group()`` returns None.

    Progress is logged with exponential backoff. The function returns
    silently if the deadline passes.

    NOTE(review): the poll checks the caller's current placement group rather
    than ``current_placement_group`` itself. Confirm this is the intended
    condition.
    """
    ray.util.remove_placement_group(current_placement_group)
    s = time.time()
    wait_interval = 10
    while time.time() - s < PG_WAIT_TIMEOUT:
        pg = ray.util.get_current_placement_group()
        if pg is None:
            break

        # Exponential backoff for warning print.
        wait_interval *= 2
        logger.info(
            "Waiting for removing a placement group of specs for %d seconds.",
            int(time.time() - s),
        )
        # Never sleep past the overall deadline as the interval keeps doubling.
        remaining = PG_WAIT_TIMEOUT - (time.time() - s)
        time.sleep(max(0.0, min(wait_interval, remaining)))


414
def _init_ray_runtime(
    parallel_config: ParallelConfig, ray_address: str | None, platform
) -> None:
    """Connect to (or launch) a Ray runtime unless one is already running.

    On ROCm/XPU an existing instance is looked up with ``ray.init("auto")``.
    If that raises ConnectionError, a new instance is started with
    ``world_size`` GPUs. Other platforms call ``ray.init`` with
    ``ray_address`` directly.
    """
    if ray.is_initialized():
        logger.info("Ray is already initialized. Skipping Ray initialization.")
        return

    if not (platform.is_rocm() or platform.is_xpu()):
        ray.init(address=ray_address, runtime_env=parallel_config.ray_runtime_env)
        return

    try:
        ray.init("auto")
    except ConnectionError:
        logger.warning(
            "No existing RAY instance detected. "
            "A new instance will be launched with current node resources."
        )
        ray.init(
            address=ray_address,
            num_gpus=parallel_config.world_size,
            runtime_env=parallel_config.ray_runtime_env,
        )


def _validate_existing_placement_group(
    placement_group: "PlacementGroup", parallel_config: ParallelConfig, device_str: str
) -> None:
    """Check that a pre-existing placement group can host all workers.

    Each bundle may hold at most one ``device_str``, and at least
    ``world_size`` bundles must hold one.

    Raises:
        ValueError: If either condition is violated.
    """
    logger.info("Using the existing placement group")
    per_bundle_devices = [
        spec.get(device_str, 0) for spec in placement_group.bundle_specs
    ]
    if any(count > 1 for count in per_bundle_devices):
        raise ValueError(
            f"Placement group bundle cannot have more than 1 {device_str}."
        )
    device_bundles = sum(1 for count in per_bundle_devices if count)
    if parallel_config.world_size > device_bundles:
        raise ValueError(
            f"The number of required {device_str}s exceeds the total "
            f"number of available {device_str}s in the placement group. "
            f"Required number of devices: {parallel_config.world_size}. "
            f"Total number of devices: {device_bundles}."
        )


def _create_worker_placement_group(
    parallel_config: ParallelConfig, device_str: str
) -> "PlacementGroup":
    """Create a PACK placement group with one ``device_str`` per worker,
    pin its first bundle to the driver's node, and wait until it is ready.

    Raises:
        ValueError: If the driver's node has no ``device_str`` available, or
            (via _wait_until_pg_ready) if the group is not ready in time.
    """
    logger.info("No current placement group found. Creating a new placement group.")
    num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
    # Only warn when the cluster looks too small. The real failure, if any,
    # surfaces later from _wait_until_pg_ready, which leaves time for the
    # cluster to become ready.
    if parallel_config.world_size > num_devices_in_cluster:
        logger.warning(
            "The number of required %ss exceeds the total "
            "number of available %ss in the placement group.",
            device_str,
            device_str,
        )
    placement_group_specs: list[dict[str, float]] = [
        {device_str: 1.0} for _ in range(parallel_config.world_size)
    ]

    # The engine process also runs a model worker, so the driver's own node
    # must offer at least one device.
    current_ip = get_ip()
    current_node_id = ray.get_runtime_context().get_node_id()
    current_node_resource = available_resources_per_node()[current_node_id]
    if current_node_resource.get(device_str, 0) < 1:
        raise ValueError(
            f"Current node has no {device_str} available. "
            f"{current_node_resource=}. vLLM engine cannot start without "
            f"{device_str}. Make sure you have at least 1 {device_str} "
            f"available in a node {current_node_id=} {current_ip=}."
        )
    # A tiny custom "node:<ip>" resource forces bundle 0 onto this node.
    placement_group_specs[0][f"node:{current_ip}"] = 0.001

    # PACK: Ray places bundles on as few nodes as possible.
    new_pg = ray.util.placement_group(placement_group_specs, strategy="PACK")
    _wait_until_pg_ready(new_pg)
    return new_pg


def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: str | None = None,
):
    """Initialize the distributed cluster with Ray.

    Connects to the Ray cluster (launching one on ROCm/XPU when none is
    found). Then it reuses the configured or current placement group, or
    creates a new one reserving one accelerator per worker. Finally it
    verifies the group and stores it on ``parallel_config.placement_group``.

    Args:
        parallel_config: The configurations for parallel execution.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Raises:
        ValueError: If Ray is unavailable, the platform has no Ray device
            key, or the placement group cannot provide the required devices.
    """
    assert_ray_available()
    from vllm.platforms import current_platform

    # Early, local-only sanity check of the GPU count (CUDA only).
    if current_platform.is_cuda() and parallel_config.world_size > 1:
        from vllm.utils.torch_utils import cuda_device_count_stateless

        available_gpus = cuda_device_count_stateless()
        if parallel_config.world_size > available_gpus:
            logger.warning(
                "Tensor parallel size (%d) exceeds available GPUs (%d). "
                "This may result in Ray placement group allocation failures. "
                "Consider reducing tensor_parallel_size to %d or less, "
                "or ensure your Ray cluster has %d GPUs available.",
                parallel_config.world_size,
                available_gpus,
                available_gpus,
                parallel_config.world_size,
            )

    _init_ray_runtime(parallel_config, ray_address, current_platform)

    device_str = current_platform.ray_device_key
    if not device_str:
        raise ValueError(
            f"current platform {current_platform.device_name} does not support ray."
        )

    # Prefer the group from the config, then the caller's current one.
    current_placement_group = (
        parallel_config.placement_group or ray.util.get_current_placement_group()
    )
    if current_placement_group:
        _validate_existing_placement_group(
            current_placement_group, parallel_config, device_str
        )
    else:
        current_placement_group = _create_worker_placement_group(
            parallel_config, device_str
        )

    assert current_placement_group is not None
    _verify_bundles(current_placement_group, parallel_config, device_str)
    parallel_config.placement_group = current_placement_group
547
548
549
550


def get_num_tpu_nodes() -> int:
    """Return the number of TPU nodes in the Ray cluster.

    Divides the cluster-wide ``TPU`` resource by the TPU count of the current
    node. This presumably assumes every node has the same number of TPUs; the
    assert requires the division to be exact.
    """
    from ray._private.accelerators import TPUAcceleratorManager

    total_tpus = int(ray.cluster_resources()["TPU"])
    tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
    num_nodes, leftover = divmod(total_tpus, tpus_per_node)
    assert leftover == 0
    return num_nodes


def get_num_nodes_in_placement_group() -> int:
    """Count the distinct nodes hosting bundles of the current placement group.

    Returns 0 when there is no current placement group, or when it is absent
    from the placement group table.
    """
    pg_table = ray.util.placement_group_table()
    current_pg = ray.util.get_current_placement_group()
    if not current_pg:
        return 0

    target_id = current_pg.id.hex()
    pg_info = next(
        (info for key, info in pg_table.items() if key == target_id), None
    )
    if pg_info is None:
        return 0
    return len({node for node in pg_info["bundles_to_node_id"].values()})