ray_utils.py 15.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
5
import time
from collections import defaultdict
6
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
7

8
9
import msgspec

10
import vllm.platforms
11
from vllm.config import ParallelConfig
12
from vllm.executor.msgspec_utils import decode_hook, encode_hook
13
from vllm.logger import init_logger
14
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
15
from vllm.utils import get_ip
16
from vllm.worker.worker_base import WorkerWrapperBase
17

18
19
20
21
if TYPE_CHECKING:
    from vllm.v1.core.scheduler import SchedulerOutput
    from vllm.v1.outputs import ModelRunnerOutput

22
logger = init_logger(__name__)

# Upper bound, in seconds, on how long the helpers below wait for a Ray
# placement group to become ready or to be removed (30 minutes).
PG_WAIT_TIMEOUT = 30 * 60
24
25
26

try:
    import ray
    from ray.util import placement_group_table
    from ray.util.placement_group import PlacementGroup

    try:
        from ray._private.state import available_resources_per_node
    except ImportError:
        # Ray 2.9.x doesn't expose `available_resources_per_node`
        from ray._private.state import state as _state
        available_resources_per_node = _state._available_resources_per_node

    class RayWorkerWrapper(WorkerWrapperBase):
        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # The compiled DAG runs the main execution in a different
            # thread, and that thread must call cuda.set_device itself.
            # This flag records whether set_device has already been called
            # on that thread.
            self.compiled_dag_cuda_device_set = False

            # msgspec codecs for the SPMD path: requests arrive as msgpack
            # bytes and non-tensor outputs are returned as msgpack bytes.
            self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
                                                         dec_hook=decode_hook)
            self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)

        def get_node_ip(self) -> str:
            """Return the IP address of the node this worker runs on."""
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            """Return this worker's Ray node id and its accelerator ids.

            The accelerator ids are looked up under the current platform's
            Ray device key.

            Raises:
                RuntimeError: if the current platform has no Ray device key,
                    i.e. it does not support Ray.
            """
            node_id = ray.get_runtime_context().get_node_id()
            device_key = vllm.platforms.current_platform.ray_device_key
            if not device_key:
                # Exceptions do not apply %-style formatting to extra
                # arguments, so the message is built before raising.
                raise RuntimeError(
                    "current platform "
                    f"{vllm.platforms.current_platform.device_name} "
                    "does not support ray.")
            gpu_ids = ray.get_runtime_context().get_accelerator_ids(
            )[device_key]
            return node_id, gpu_ids

        def execute_model_spmd(
            self, req_or_tuple: Union[bytes,
                                      Tuple[bytes,
                                            Optional[IntermediateTensors]]]
        ) -> bytes:
            """Execute model in SPMD fashion: used only when SPMD worker and
            compiled DAG are both enabled.

            Args:
                req_or_tuple: Either a serialized request, or a tuple of the
                    serialized request and intermediate tensors. Intermediate
                    tensors are None unless they are provided because this
                    worker is at a pipeline stage > 0. The request is
                    serialized by msgspec.

            Returns:
                The msgspec-encoded model output. If the worker produced
                IntermediateTensors, a ``(serialized_req, output)`` tuple is
                returned instead, so it can be passed to the next stage.
            """
            if isinstance(req_or_tuple, bytes):
                serialized_req, intermediate_tensors = req_or_tuple, None
            else:
                serialized_req, intermediate_tensors = req_or_tuple

            execute_model_req = self.input_decoder.decode(serialized_req)

            # The compiled DAG executes on a background thread, so torch's
            # current device has to be set on that thread (once).
            self.setup_device_if_necessary()

            output = self.worker._execute_model_spmd(execute_model_req,
                                                     intermediate_tensors)
            # Pipeline model request and output to the next pipeline stage.
            if isinstance(output, IntermediateTensors):
                output = serialized_req, output
            else:
                output = self.output_encoder.encode(output)

            return output

        def setup_device_if_necessary(self):
            """Set torch's current CUDA device to the worker's device once.

            Subsequent calls are no-ops, tracked by
            ``compiled_dag_cuda_device_set``.
            """
            # TODO(swang): This is needed right now because Ray CG executes
            # on a background thread, so we need to reset torch's current
            # device.
            # We can remove this API after it is fixed in compiled graph.
            import torch
            assert self.worker is not None, "Worker is not initialized"
            if not self.compiled_dag_cuda_device_set:
                torch.cuda.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

        def execute_model(
            self,
            scheduler_output: "SchedulerOutput",
        ) -> "ModelRunnerOutput":
            """Run the worker's model runner on ``scheduler_output``."""
            self.setup_device_if_necessary()
            assert self.worker is not None, "Worker is not initialized"
            output = self.worker.model_runner.execute_model(scheduler_output)
            return output

        def override_env_vars(self, vars: Dict[str, str]):
            """Update this process's environment with the given variables."""
            os.environ.update(vars)

    ray_import_err = None

except ImportError as e:
    # Ray is optional: record the failure so assert_ray_available() can
    # report it later.
    ray = None  # type: ignore
    ray_import_err = e
    RayWorkerWrapper = None  # type: ignore
133
134


135
136
137
138
139
140
141
142
143
144
145
146
def ray_is_available() -> bool:
    """Tell whether the optional Ray dependency could be imported."""
    available = ray is not None
    return available


def assert_ray_available():
    """Ensure Ray can be used, raising otherwise.

    Raises:
        ValueError: if importing Ray failed; the original import error is
            chained as the cause.
    """
    if ray is not None:
        return
    raise ValueError("Failed to import Ray, please install Ray with "
                     "`pip install ray`.") from ray_import_err


147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
def _verify_bundles(placement_group: "PlacementGroup",
                    parallel_config: ParallelConfig, device_str: str):
    """Verify a given placement group has bundles located in the right place.

    There are 2 rules.
    - Warn if all tensor parallel workers cannot fit in a single node.
    - Fail if driver node is not included in a placement group.

    Args:
        placement_group: The placement group to inspect.
        parallel_config: Supplies ``tensor_parallel_size`` for the per-node
            check.
        device_str: Ray resource name of the accelerator (e.g. "GPU"), used
            in the diagnostic messages.

    Raises:
        RuntimeError: if none of the group's bundles is on the driver node.
    """
    assert ray.is_initialized(), (
        "Ray is not initialized although distributed-executor-backend is ray.")
    pg_data = placement_group_table(placement_group)
    # bundle_idx -> node_id
    bundle_to_node_ids = pg_data["bundles_to_node_id"]
    # bundle_idx -> bundle (e.g., {"GPU": 1})
    bundles = pg_data["bundles"]
    # node_id -> List of bundle (e.g., {"GPU": 1})
    node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)

    for bundle_idx, node_id in bundle_to_node_ids.items():
        node_id_to_bundle[node_id].append(bundles[bundle_idx])
    driver_node_id = ray.get_runtime_context().get_node_id()

    if driver_node_id not in node_id_to_bundle:
        raise RuntimeError(
            f"driver node id {driver_node_id} is not included in a placement "
            f"group {placement_group.id}. Node id -> bundles "
            f"{node_id_to_bundle}. "
            f"You don't have enough {device_str}s available in a current "
            f"node. Check `ray status` to see if you have available "
            f"{device_str}s in a node {driver_node_id} before starting an "
            "vLLM engine.")

    # A distinct name avoids shadowing the bundle_idx -> bundle map above.
    for node_id, node_bundles in node_id_to_bundle.items():
        if len(node_bundles) < parallel_config.tensor_parallel_size:
            logger.warning(
                "tensor_parallel_size=%d "
                "is bigger than a reserved number of %ss (%d "
                "%ss) in a node %s. Tensor parallel workers can be "
                "spread out to 2+ nodes which can degrade the performance "
                "unless you have fast interconnect across nodes, like "
                "Infiniband. To resolve this issue, make sure you have at "
                "least %d %ss available at each node.",
                parallel_config.tensor_parallel_size, device_str,
                len(node_bundles), device_str, node_id,
                parallel_config.tensor_parallel_size, device_str)


def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
    """Wait until a placement group is ready.

    While the group is still pending, informative log messages are printed
    at exponentially growing intervals. The total wait never exceeds
    PG_WAIT_TIMEOUT seconds.

    Raises:
        ValueError: if the placement group is not ready within
            PG_WAIT_TIMEOUT seconds.
    """
    # Wait until PG is ready - this will block until all
    # requested resources are available, and will timeout
    # if they cannot be provisioned.
    placement_group_specs = current_placement_group.bundle_specs

    s = time.time()
    pg_ready_ref = current_placement_group.ready()
    wait_interval = 10
    while True:
        remaining = PG_WAIT_TIMEOUT - (time.time() - s)
        if remaining <= 0:
            break
        # Cap each wait at the time left, so exponential backoff cannot
        # push the total wait past PG_WAIT_TIMEOUT.
        ready, _ = ray.wait([pg_ready_ref],
                            timeout=min(wait_interval, remaining))
        if ready:
            break

        # Exponential backoff for warning print.
        wait_interval *= 2
        logger.info(
            "Waiting for creating a placement group of specs for "
            "%d seconds. specs=%s. Check "
            "`ray status` to see if you have enough resources,"
            " and make sure the IP addresses used by ray cluster"
            " are the same as VLLM_HOST_IP environment variable"
            " specified in each node if you are running on a multi-node.",
            int(time.time() - s), placement_group_specs)

    try:
        # timeout=0 only checks readiness; it raises if still pending.
        ray.get(pg_ready_ref, timeout=0)
    except ray.exceptions.GetTimeoutError:
        raise ValueError(
            "Cannot provide a placement group of "
            f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
            "`ray status` to make sure the cluster has enough resources."
        ) from None


def _wait_until_pg_removed(current_placement_group: "PlacementGroup"):
    """Remove a placement group and wait until Ray reports it as removed.

    The group's state is polled with exponentially growing sleeps, and the
    total wait never exceeds PG_WAIT_TIMEOUT seconds. If the group is still
    not reported as removed by then, a warning is logged and the function
    returns.
    """
    ray.util.remove_placement_group(current_placement_group)
    s = time.time()
    wait_interval = 10
    while True:
        # Query the state of this specific placement group.
        # ray.util.get_current_placement_group() would instead return the
        # group the calling task/actor runs in, which says nothing about
        # whether this group has been removed.
        state = placement_group_table(current_placement_group).get("state")
        if state == "REMOVED":
            return

        elapsed = time.time() - s
        if elapsed >= PG_WAIT_TIMEOUT:
            break

        # Exponential backoff for warning print.
        wait_interval *= 2
        logger.info(
            "Waiting for removing a placement group of specs for "
            "%d seconds.", int(elapsed))
        # Never sleep past the overall deadline.
        time.sleep(min(wait_interval, PG_WAIT_TIMEOUT - elapsed))

    logger.warning("Placement group %s was not removed within %d seconds.",
                   current_placement_group.id, PG_WAIT_TIMEOUT)


250
def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
):
    """Initialize the distributed cluster with Ray.

    It will connect to the Ray cluster and create a placement group
    for the workers, which includes the specification of the resources
    for each distributed worker.

    Args:
        parallel_config: The configurations for parallel execution. Its
            ``placement_group`` is filled in by this function unless one is
            already set, in which case only the Ray connection is made.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Raises:
        ValueError: if Ray is not installed, the current platform does not
            support Ray, an existing placement group cannot host the
            workers, or the current node has no device available.
    """
    assert_ray_available()
    from vllm.platforms import current_platform

    # Connect to a ray cluster.
    if current_platform.is_rocm() or current_platform.is_xpu():
        # Try to connect existing ray instance and create a new one if not found
        try:
            ray.init("auto", ignore_reinit_error=True)
        except ConnectionError:
            logger.warning(
                "No existing RAY instance detected. "
                "A new instance will be launched with current node resources.")
            ray.init(address=ray_address,
                     ignore_reinit_error=True,
                     num_gpus=parallel_config.world_size)
    else:
        ray.init(address=ray_address, ignore_reinit_error=True)

    if parallel_config.placement_group:
        # Placement group is already set.
        return

    device_str = current_platform.ray_device_key
    if not device_str:
        raise ValueError(
            f"current platform {current_platform.device_name} does not "
            "support ray.")

    # Create placement group for worker processes
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group: every bundle may hold
        # at most one device, and there must be one device bundle per worker.
        device_bundles = 0
        for bundle in bundles:
            bundle_devices = bundle.get(device_str, 0)
            if bundle_devices > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 "
                    f"{device_str}.")
            if bundle_devices:
                device_bundles += 1
        if parallel_config.world_size > device_bundles:
            raise ValueError(
                f"The number of required {device_str}s exceeds the total "
                f"number of available {device_str}s in the placement group. "
                f"Required number of devices: {parallel_config.world_size}. "
                f"Total number of devices: {device_bundles}.")
    else:
        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
        # Log a warning message and delay resource allocation failure response.
        # Avoid immediate rejection to allow user-initiated placement group
        # created and wait cluster to be ready
        if parallel_config.world_size > num_devices_in_cluster:
            logger.warning(
                "The number of required %ss (%d) exceeds the total "
                "number of available %ss (%d) in the Ray cluster.",
                device_str, parallel_config.world_size, device_str,
                num_devices_in_cluster)
        # Create a new placement group
        placement_group_specs: List[Dict[str, float]] = ([{
            device_str: 1.0
        } for _ in range(parallel_config.world_size)])

        # vLLM engine is also a worker to execute model with an accelerator,
        # so it requires to have the device in a current node. Check if
        # the current node has at least one device.
        current_ip = get_ip()
        current_node_id = ray.get_runtime_context().get_node_id()
        current_node_resource = available_resources_per_node()[current_node_id]
        if current_node_resource.get(device_str, 0) < 1:
            raise ValueError(
                f"Current node has no {device_str} available. "
                f"{current_node_resource=}. vLLM engine cannot start without "
                f"{device_str}. Make sure you have at least 1 {device_str} "
                f"available in a node {current_node_id=} {current_ip=}.")
        # This way, at least one bundle is required to be created in the
        # current node.
        placement_group_specs[0][f"node:{current_ip}"] = 0.001

        # By default, Ray packs resources as much as possible.
        current_placement_group = ray.util.placement_group(
            placement_group_specs, strategy="PACK")
        _wait_until_pg_ready(current_placement_group)

    assert current_placement_group is not None
    _verify_bundles(current_placement_group, parallel_config, device_str)
    # Set the placement group in the parallel config
    parallel_config.placement_group = current_placement_group
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378


def get_num_tpu_nodes() -> int:
    """Return the number of TPU nodes in the Ray cluster.

    Computed as the cluster-wide "TPU" resource count divided by the number
    of TPUs detected on the current node, so it assumes every TPU node has
    as many TPUs as this one.

    Raises:
        KeyError: if the cluster exposes no "TPU" resource.
        AssertionError: if the current node has no TPUs, or the cluster
            total is not a multiple of the per-node count.
    """
    from ray._private.accelerators import TPUAcceleratorManager
    cluster_resources = ray.cluster_resources()
    total_tpus = int(cluster_resources["TPU"])
    tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
    # Guard the modulo/division below: a node without TPUs would otherwise
    # surface as an opaque ZeroDivisionError.
    assert tpus_per_node > 0, (
        "No TPUs detected on the current node; cannot infer the number of "
        "TPU nodes in the cluster.")
    assert total_tpus % tpus_per_node == 0, (
        f"Total TPUs in the cluster ({total_tpus}) is not a multiple of the "
        f"TPUs on the current node ({tpus_per_node}).")
    return total_tpus // tpus_per_node


def get_num_nodes_in_placement_group() -> int:
    """Return how many distinct nodes host bundles of the current
    placement group.

    Returns 0 when the caller is not running inside a placement group, or
    when that group is missing from the placement group table.
    """
    pg_table = ray.util.placement_group_table()
    current_pg = ray.util.get_current_placement_group()
    if not current_pg:
        return 0

    # The table is keyed by the hex placement group id, so look the current
    # group up directly instead of scanning every entry.
    pg_info = pg_table.get(current_pg.id.hex())
    if pg_info is None:
        return 0
    return len(set(pg_info["bundles_to_node_id"].values()))