ray_utils.py 17.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
6
import time
from collections import defaultdict
7
from typing import TYPE_CHECKING, Union
8

9
10
import msgspec

11
import vllm.platforms
12
from vllm.config import ParallelConfig
13
from vllm.distributed import get_pp_group
14
from vllm.executor.msgspec_utils import decode_hook, encode_hook
15
from vllm.logger import init_logger
16
from vllm.platforms import current_platform
17
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
18
from vllm.utils import get_ip
19
from vllm.v1.outputs import AsyncModelRunnerOutput
20
from vllm.v1.worker.worker_base import WorkerWrapperBase
21

22
if TYPE_CHECKING:
23
    from vllm.v1.core.sched.output import SchedulerOutput
24
25
    from vllm.v1.outputs import ModelRunnerOutput

26
logger = init_logger(__name__)
# Upper bound, in seconds, on how long to wait for a placement group to be
# created or removed (30 minutes).
PG_WAIT_TIMEOUT = 30 * 60
28
29
30

try:
    import ray
31
32
    from ray.util import placement_group_table
    from ray.util.placement_group import PlacementGroup
33

34
35
36
37
38
    try:
        from ray._private.state import available_resources_per_node
    except ImportError:
        # Ray 2.9.x doesn't expose `available_resources_per_node`
        from ray._private.state import state as _state
39

40
        available_resources_per_node = _state._available_resources_per_node
41

42
    class RayWorkerWrapper(WorkerWrapperBase):
        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # The compiled DAG runs its main execution on a different thread,
            # which must call set_device itself. This flag records whether
            # set_device has already been called on that thread.
            self.compiled_dag_cuda_device_set = False

            # msgspec codecs used by execute_model_spmd: incoming requests are
            # msgpack-encoded ExecuteModelRequest bytes, and outputs that are
            # not forwarded as IntermediateTensors are msgpack-encoded.
            self.input_decoder = msgspec.msgpack.Decoder(
                ExecuteModelRequest, dec_hook=decode_hook
            )
            self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)

        def get_node_ip(self) -> str:
            """Return the IP address reported by ``get_ip()`` on this worker."""
            return get_ip()

        def get_node_and_gpu_ids(self) -> tuple[str, list[int]]:
            """Return this actor's Ray node id and its assigned accelerator ids.

            The accelerator ids are looked up under the platform's
            ``ray_device_key`` in Ray's runtime context.

            Raises:
                RuntimeError: If the current platform has no Ray device key.
            """
            node_id = ray.get_runtime_context().get_node_id()
            platform = vllm.platforms.current_platform
            device_key = platform.ray_device_key
            if not device_key:
                # Exceptions do not apply %-style formatting to their
                # arguments, so the message must be fully built here.
                raise RuntimeError(
                    f"current platform {platform.device_name} does not support ray."
                )
            gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key]
            return node_id, gpu_ids

        def execute_model_spmd(
            self,
            req_or_tuple: bytes | tuple[bytes, IntermediateTensors | None],
        ) -> bytes | tuple[bytes, IntermediateTensors]:
            """Execute model in SPMD fashion: used only when SPMD worker and
            compiled DAG are both enabled.

            Args:
                req_or_tuple: Either the msgspec-serialized request alone, or
                    a tuple of the serialized request and the intermediate
                    tensors. Intermediate tensors are None unless they were
                    provided because this is a pipeline stage > 0.

            Returns:
                A ``(serialized_req, intermediate_tensors)`` tuple to forward
                to the next pipeline stage when the worker produced
                IntermediateTensors; otherwise the msgspec-encoded output.
            """
            if isinstance(req_or_tuple, bytes):
                serialized_req, intermediate_tensors = req_or_tuple, None
            else:
                serialized_req, intermediate_tensors = req_or_tuple

            execute_model_req = self.input_decoder.decode(serialized_req)

            # TODO(swang): This is needed right now because Ray Compiled Graph
            # executes on a background thread, so we need to reset torch's
            # current device.
            if not self.compiled_dag_cuda_device_set:
                current_platform.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

            output = self.worker._execute_model_spmd(
                execute_model_req, intermediate_tensors
            )
            # Pipeline model request and output to the next pipeline stage.
            if isinstance(output, IntermediateTensors):
                output = serialized_req, output
            else:
                output = self.output_encoder.encode(output)

            return output

        def setup_device_if_necessary(self):
            """Set the torch device once on the compiled-graph execution thread.

            TODO(swang): This is needed right now because Ray CG executes on a
            background thread, so we need to reset torch's current device.
            We can remove this API after it is fixed in compiled graph.
            """
            assert self.worker is not None, "Worker is not initialized"
            if not self.compiled_dag_cuda_device_set:
                # Not needed on TPU.
                if not current_platform.is_tpu():
                    current_platform.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

        def execute_model_ray(
            self,
            scheduler_output: Union[
                "SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
            ],
        ) -> Union[
            "ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
        ]:
            """Run one model step on behalf of Ray Compiled Graph.

            Args:
                scheduler_output: The scheduler output, optionally paired with
                    the intermediate tensors received from the previous
                    pipeline stage.

            Returns:
                ``(scheduler_output, intermediate_tensors)`` when the step
                produced IntermediateTensors; ``(scheduler_output, None)`` on a
                non-last pipeline rank that produced no tensors; otherwise the
                model runner output (an AsyncModelRunnerOutput is resolved via
                ``get_output()`` first).
            """
            # This method is used by Ray Compiled Graph to execute the model,
            # and it needs a special logic of self.setup_device_if_necessary()
            self.setup_device_if_necessary()
            assert self.worker is not None, "Worker is not initialized"
            if isinstance(scheduler_output, tuple):
                scheduler_output, intermediate_tensors = scheduler_output
            else:
                scheduler_output, intermediate_tensors = scheduler_output, None
            output = self.worker.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            if isinstance(output, IntermediateTensors):
                output = scheduler_output, output
            elif not get_pp_group().is_last_rank:
                # Case where there are no scheduled requests
                # but may still be finished requests.
                assert not output or not output.req_ids
                output = scheduler_output, None
            # Ensure outputs crossing Ray compiled DAG are serializable.
            # AsyncModelRunnerOutput holds CUDA events and cannot be
            # pickled.
            if isinstance(output, AsyncModelRunnerOutput):
                output = output.get_output()
            return output

        def override_env_vars(self, vars: dict[str, str]):
            """Merge ``vars`` into this worker process's ``os.environ``."""
            os.environ.update(vars)

162
163
    ray_import_err = None

164
except ImportError as e:
165
    ray = None  # type: ignore
166
167
168
    # only capture string to avoid variable references in the traceback that can
    # prevent garbage collection in some cases
    ray_import_err = str(e)
169
    RayWorkerWrapper = None  # type: ignore
170
171


172
173
174
175
176
177
178
179
def ray_is_available() -> bool:
    """Report whether the optional ``ray`` package was imported successfully."""
    if ray is None:
        return False
    return True


def assert_ray_available():
    """Raise an exception if Ray is not available.

    Raises:
        ValueError: If importing ``ray`` failed when this module was loaded;
            the message includes the recorded import error text.
    """
    if ray is None:
        raise ValueError(
            f"Failed to import Ray: {ray_import_err}. "
            "Please install Ray with `pip install ray`."
        )
184
185


186
187
188
def _verify_bundles(
    placement_group: "PlacementGroup", parallel_config: ParallelConfig, device_str: str
):
    """Verify a given placement group has bundles located in the right place.

    There are 2 rules.
    - Warn if all tensor parallel workers cannot fit in a single node.
    - Fail if driver node is not included in a placement group.

    Args:
        placement_group: The placement group whose bundle placement is checked.
        parallel_config: Supplies ``tensor_parallel_size`` for the fit check.
        device_str: Ray resource key of the accelerator, used in messages.

    Raises:
        RuntimeError: If no bundle of the placement group is on the driver's
            node.
    """
    assert ray.is_initialized(), (
        "Ray is not initialized although distributed-executor-backend is ray."
    )
    pg_data = placement_group_table(placement_group)
    # bundle_idx -> node_id
    bundle_to_node_ids = pg_data["bundles_to_node_id"]
    # bundle_idx -> bundle (e.g., {"GPU": 1})
    bundles = pg_data["bundles"]
    # node_id -> List of bundle (e.g., {"GPU": 1})
    node_id_to_bundle: dict[str, list[dict[str, float]]] = defaultdict(list)

    for bundle_idx, node_id in bundle_to_node_ids.items():
        node_id_to_bundle[node_id].append(bundles[bundle_idx])
    driver_node_id = ray.get_runtime_context().get_node_id()

    if driver_node_id not in node_id_to_bundle:
        raise RuntimeError(
            f"driver node id {driver_node_id} is not included in a placement "
            f"group {placement_group.id}. Node id -> bundles "
            f"{node_id_to_bundle}. "
            "You don't have enough GPUs available in a current node. Check "
            "`ray status` and `ray list nodes` to see if you have available "
            f"GPUs in a node `{driver_node_id}` before starting a vLLM engine."
        )

    for node_id, node_bundles in node_id_to_bundle.items():
        if len(node_bundles) < parallel_config.tensor_parallel_size:
            logger.warning(
                "tensor_parallel_size=%d "
                "is bigger than a reserved number of %ss (%d "
                "%ss) in a node %s. Tensor parallel workers can be "
                "spread out to 2+ nodes which can degrade the performance "
                "unless you have fast interconnect across nodes, like "
                "Infiniband. To resolve this issue, make sure you have at "
                "least %d %ss available at each node.",
                parallel_config.tensor_parallel_size,
                device_str,
                len(node_bundles),
                device_str,
                node_id,
                parallel_config.tensor_parallel_size,
                device_str,
            )
237
238
239
240
241
242
243
244
245
246


def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
    """Wait until a placement group is ready.

    It prints the informative log messages if the placement group is
    not created within time. Polling uses exponentially growing intervals,
    each capped at the time left before ``PG_WAIT_TIMEOUT`` so the total
    wait does not overshoot the timeout.

    Args:
        current_placement_group: The placement group to wait for.

    Raises:
        ValueError: If the placement group is not ready within
            ``PG_WAIT_TIMEOUT`` seconds.
    """
    # Wait until PG is ready - this will block until all
    # requested resources are available, and will time out
    # if they cannot be provisioned.
    placement_group_specs = current_placement_group.bundle_specs

    # Monotonic clock so elapsed-time checks are immune to wall-clock changes.
    start = time.monotonic()
    pg_ready_ref = current_placement_group.ready()
    wait_interval = 10
    while True:
        remaining = PG_WAIT_TIMEOUT - (time.monotonic() - start)
        if remaining <= 0:
            break
        ready, _ = ray.wait([pg_ready_ref], timeout=min(wait_interval, remaining))
        if ready:
            break

        # Exponential backoff for warning print.
        wait_interval *= 2
        logger.info(
            "Waiting for creating a placement group of specs for "
            "%d seconds. specs=%s. Check `ray status` and "
            "`ray list nodes` to see if you have enough resources,"
            " and make sure the IP addresses used by ray cluster"
            " are the same as VLLM_HOST_IP environment variable"
            " specified in each node if you are running on a multi-node.",
            int(time.monotonic() - start),
            placement_group_specs,
        )

    try:
        # timeout=0: only checks whether the group became ready; no waiting.
        ray.get(pg_ready_ref, timeout=0)
    except ray.exceptions.GetTimeoutError:
        raise ValueError(
            "Cannot provide a placement group of "
            f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
            "`ray status` and `ray list nodes` to make sure the cluster has "
            "enough resources."
        ) from None
281
282
283
284
285
286
287
288
289
290
291
292
293
294


def _wait_until_pg_removed(current_placement_group: "PlacementGroup"):
    """Remove a placement group and wait until Ray reports it as removed.

    Polls the group's state with exponentially growing sleeps, each capped at
    the time left before ``PG_WAIT_TIMEOUT``. If the group is still present
    when the timeout expires, a warning is logged and the function returns.

    Args:
        current_placement_group: The placement group to remove.
    """
    ray.util.remove_placement_group(current_placement_group)
    start = time.monotonic()
    wait_interval = 10
    while True:
        # Query the state of this specific placement group.
        # ray.util.get_current_placement_group() would instead report the
        # caller's own placement group, which says nothing about this one.
        # An empty table entry means the group is no longer known to Ray.
        state = placement_group_table(current_placement_group).get("state")
        if state is None or state == "REMOVED":
            return

        elapsed = time.monotonic() - start
        remaining = PG_WAIT_TIMEOUT - elapsed
        if remaining <= 0:
            break

        # Exponential backoff for warning print.
        wait_interval *= 2
        logger.info(
            "Waiting for removing a placement group of specs for %d seconds.",
            int(elapsed),
        )
        time.sleep(min(wait_interval, remaining))

    logger.warning(
        "Placement group %s was not removed within %d seconds.",
        current_placement_group.id,
        PG_WAIT_TIMEOUT,
    )


301
def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: str | None = None,
):
    """Initialize the distributed cluster with Ray.

    It will connect to the Ray cluster and create a placement group
    for the workers, which includes the specification of the resources
    for each distributed worker. If ``parallel_config.placement_group`` is
    set, or the caller is already inside a placement group, that group is
    validated and reused instead of creating a new one.

    Args:
        parallel_config: The configurations for parallel execution. Its
            ``placement_group`` field is set to the placement group in use.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Raises:
        ValueError: If Ray is not importable, the platform has no Ray device
            key, an existing placement group cannot host the workers, the
            current node has no device, or a new placement group is not ready
            within ``PG_WAIT_TIMEOUT`` seconds.
        RuntimeError: If no bundle of the placement group is on the driver's
            node (raised by the bundle verification step).
    """
    assert_ray_available()
    # NOTE(review): re-imported at call time even though the module imports
    # it too; presumably to pick up the resolved platform object — confirm.
    from vllm.platforms import current_platform

    if ray.is_initialized():
        logger.info("Ray is already initialized. Skipping Ray initialization.")
    elif current_platform.is_rocm() or current_platform.is_xpu():
        # Try to connect existing ray instance and create a new one if not found.
        # NOTE(review): this first attempt does not pass ``ray_address`` or
        # ``ray_runtime_env``; only the fallback below does.
        try:
            ray.init("auto")
        except ConnectionError:
            logger.warning(
                "No existing RAY instance detected. "
                "A new instance will be launched with current node resources."
            )
            ray.init(
                address=ray_address,
                num_gpus=parallel_config.world_size,
                runtime_env=parallel_config.ray_runtime_env,
            )
    else:
        ray.init(address=ray_address, runtime_env=parallel_config.ray_runtime_env)

    device_str = current_platform.ray_device_key
    if not device_str:
        raise ValueError(
            f"current platform {current_platform.device_name} does not support ray."
        )

    # Create or get the placement group for worker processes
    if parallel_config.placement_group:
        current_placement_group = parallel_config.placement_group
    else:
        current_placement_group = ray.util.get_current_placement_group()

    if current_placement_group:
        logger.info("Using the existing placement group")

        # We are in a placement group
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group: each bundle may hold at
        # most one device, and there must be one device bundle per worker.
        device_bundles = 0
        for bundle in bundles:
            bundle_devices = bundle.get(device_str, 0)
            if bundle_devices > 1:
                raise ValueError(
                    f"Placement group bundle cannot have more than 1 {device_str}."
                )
            if bundle_devices:
                device_bundles += 1
        if parallel_config.world_size > device_bundles:
            raise ValueError(
                f"The number of required {device_str}s exceeds the total "
                f"number of available {device_str}s in the placement group. "
                f"Required number of devices: {parallel_config.world_size}. "
                f"Total number of devices: {device_bundles}."
            )
    else:
        logger.info("No current placement group found. Creating a new placement group.")
        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
        # Log a warning message and delay resource allocation failure response.
        # Avoid immediate rejection to allow a user-initiated placement group
        # to be created and to wait for the cluster to be ready.
        if parallel_config.world_size > num_devices_in_cluster:
            logger.warning(
                "The number of required %ss (%d) exceeds the total "
                "number of available %ss (%d) in the Ray cluster.",
                device_str,
                parallel_config.world_size,
                device_str,
                num_devices_in_cluster,
            )
        # Create a new placement group: one single-device bundle per worker.
        placement_group_specs: list[dict[str, float]] = [
            {device_str: 1.0} for _ in range(parallel_config.world_size)
        ]

        # vLLM engine is also a worker to execute model with an accelerator,
        # so it requires to have the device in a current node. Check if
        # the current node has at least one device.
        current_ip = get_ip()
        current_node_id = ray.get_runtime_context().get_node_id()
        current_node_resource = available_resources_per_node()[current_node_id]
        if current_node_resource.get(device_str, 0) < 1:
            raise ValueError(
                f"Current node has no {device_str} available. "
                f"{current_node_resource=}. vLLM engine cannot start without "
                f"{device_str}. Make sure you have at least 1 {device_str} "
                f"available in a node {current_node_id=} {current_ip=}."
            )
        # Requesting a tiny amount of the "node:<ip>" resource requires at
        # least one bundle to be created on the current node.
        placement_group_specs[0][f"node:{current_ip}"] = 0.001

        # By default, Ray packs resources as much as possible.
        current_placement_group = ray.util.placement_group(
            placement_group_specs, strategy="PACK"
        )
        _wait_until_pg_ready(current_placement_group)

    assert current_placement_group is not None
    _verify_bundles(current_placement_group, parallel_config, device_str)
    # Set the placement group in the parallel config
    parallel_config.placement_group = current_placement_group
417
418
419
420


def get_num_tpu_nodes() -> int:
    """Return the number of TPU nodes in the Ray cluster.

    Divides the cluster-wide ``TPU`` resource count by the number of TPUs
    detected on the current node, which implicitly assumes every node has
    the same TPU count as this one. Asserts that the division is exact.
    """
    from ray._private.accelerators import TPUAcceleratorManager

    total_tpus = int(ray.cluster_resources()["TPU"])
    tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
    num_nodes, leftover = divmod(total_tpus, tpus_per_node)
    assert leftover == 0
    return num_nodes


def get_num_nodes_in_placement_group() -> int:
    """Count distinct nodes hosting bundles of the current placement group.

    "Current" is whatever ``ray.util.get_current_placement_group()`` returns.
    Returns 0 when that is falsy or when the group is absent from Ray's
    placement group table.
    """
    table = ray.util.placement_group_table()
    active_pg = ray.util.get_current_placement_group()
    if not active_pg:
        return 0

    entry = table.get(active_pg.id.hex())
    if entry is None:
        return 0
    return len(set(entry["bundles_to_node_id"].values()))