ray_utils.py 17.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
6
import time
from collections import defaultdict
7
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8

9
10
import msgspec

11
import vllm.platforms
12
from vllm.config import ParallelConfig
13
from vllm.distributed import get_pp_group
14
from vllm.executor.msgspec_utils import decode_hook, encode_hook
15
from vllm.logger import init_logger
16
from vllm.platforms import current_platform
17
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
18
from vllm.utils import get_ip
19
from vllm.v1.outputs import AsyncModelRunnerOutput
20
from vllm.v1.worker.worker_base import WorkerWrapperBase
21

22
if TYPE_CHECKING:
23
    from vllm.v1.core.sched.output import SchedulerOutput
24
25
    from vllm.v1.outputs import ModelRunnerOutput

26
logger = init_logger(__name__)
27
PG_WAIT_TIMEOUT = 1800
28
29
30

try:
    import ray
    from ray.util import placement_group_table
    from ray.util.placement_group import PlacementGroup

    try:
        from ray._private.state import available_resources_per_node
    except ImportError:
        # Ray 2.9.x doesn't expose `available_resources_per_node`
        from ray._private.state import state as _state
        available_resources_per_node = _state._available_resources_per_node

    class RayWorkerWrapper(WorkerWrapperBase):
        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # The compiled DAG runs its main execution loop on a different
            # thread, and that thread must call set_device itself. This flag
            # records whether set_device has already been called there.
            self.compiled_dag_cuda_device_set = False

            # msgspec codecs for the SPMD path: requests arrive as msgpack
            # bytes and outputs are sent back as msgpack bytes.
            self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
                                                         dec_hook=decode_hook)
            self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)

        def get_node_ip(self) -> str:
            """Return the IP address reported by ``get_ip()`` on this node."""
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            """Return this worker's Ray node id and its accelerator ids.

            Returns:
                A ``(node_id, gpu_ids)`` tuple, where ``gpu_ids`` are the
                accelerator ids Ray assigned to this worker under the
                platform's Ray device key.

            Raises:
                RuntimeError: If the current platform has no Ray device key.
            """
            node_id = ray.get_runtime_context().get_node_id()
            device_key = vllm.platforms.current_platform.ray_device_key
            if not device_key:
                # Exceptions do not apply %-style formatting to extra
                # arguments, so build the message eagerly.
                raise RuntimeError(
                    "current platform "
                    f"{vllm.platforms.current_platform.device_name} "
                    "does not support ray.")
            gpu_ids = ray.get_runtime_context().get_accelerator_ids(
            )[device_key]
            return node_id, gpu_ids

        def execute_model_spmd(
            self, req_or_tuple: Union[bytes,
                                      Tuple[bytes,
                                            Optional[IntermediateTensors]]]
        ) -> Union[bytes, Tuple[bytes, IntermediateTensors]]:
            """Execute model in SPMD fashion: used only when SPMD worker and
            compiled DAG are both enabled.

            Args:
                req_or_tuple: Either a request on its own, or a tuple of the
                    request and intermediate tensors. The intermediate
                    tensors are None except on pipeline stages after the
                    first. The request is serialized by msgspec.

            Returns:
                The msgspec-encoded worker output. If the worker instead
                produced ``IntermediateTensors`` (not the last pipeline
                stage), returns a ``(serialized_req, intermediate_tensors)``
                tuple to feed the next stage.
            """
            if isinstance(req_or_tuple, bytes):
                serialized_req, intermediate_tensors = req_or_tuple, None
            else:
                serialized_req, intermediate_tensors = req_or_tuple

            execute_model_req = self.input_decoder.decode(serialized_req)

            # TODO(swang): This is needed right now because Ray Compiled Graph
            # executes on a background thread, so we need to reset torch's
            # current device.
            if not self.compiled_dag_cuda_device_set:
                current_platform.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

            output = self.worker._execute_model_spmd(execute_model_req,
                                                     intermediate_tensors)
            # Pipeline model request and output to the next pipeline stage.
            if isinstance(output, IntermediateTensors):
                output = serialized_req, output
            else:
                output = self.output_encoder.encode(output)

            return output

        def setup_device_if_necessary(self):
            """Set the worker's device once on the compiled-graph thread.

            Does nothing after the first call. On TPU the device is not set,
            but the flag is still recorded.
            """
            # TODO(swang): This is needed right now because Ray CG executes
            # on a background thread, so we need to reset torch's current
            # device.
            # We can remove this API after it is fixed in compiled graph.
            assert self.worker is not None, "Worker is not initialized"
            if not self.compiled_dag_cuda_device_set:
                # TPU does not need the device to be reset.
                if not current_platform.is_tpu():
                    current_platform.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

        def execute_model_ray(
            self,
            scheduler_output: Union["SchedulerOutput",
                                    Tuple["SchedulerOutput",
                                          "IntermediateTensors"]],
        ) -> Union["ModelRunnerOutput", Tuple["SchedulerOutput",
                                              "IntermediateTensors"]]:
            """Run one model step on behalf of Ray Compiled Graph.

            Args:
                scheduler_output: Either a ``SchedulerOutput`` on its own, or
                    a ``(SchedulerOutput, IntermediateTensors)`` tuple coming
                    from the previous pipeline stage.

            Returns:
                A ``(scheduler_output, intermediate_tensors)`` tuple when the
                model produced intermediate tensors. A
                ``(scheduler_output, None)`` tuple on a non-last pipeline rank
                with no output. Otherwise the model runner output, with any
                ``AsyncModelRunnerOutput`` resolved so it can be serialized.
            """
            # This method is used by Ray Compiled Graph to execute the model,
            # and it needs a special logic of self.setup_device_if_necessary()
            self.setup_device_if_necessary()
            assert self.worker is not None, "Worker is not initialized"
            if isinstance(scheduler_output, tuple):
                scheduler_output, intermediate_tensors = scheduler_output
            else:
                scheduler_output, intermediate_tensors = scheduler_output, None
            output = self.worker.model_runner.execute_model(
                scheduler_output, intermediate_tensors)
            if isinstance(output, IntermediateTensors):
                output = scheduler_output, output
            elif not get_pp_group().is_last_rank:
                # Case where there are no scheduled requests
                # but may still be finished requests.
                assert not output or not output.req_ids
                output = scheduler_output, None
            # Ensure outputs crossing Ray compiled DAG are serializable.
            # AsyncModelRunnerOutput holds CUDA events and cannot be
            # pickled.
            if isinstance(output, AsyncModelRunnerOutput):
                output = output.get_output()
            return output

        def override_env_vars(self, vars: Dict[str, str]):
            """Merge ``vars`` into this worker process's ``os.environ``."""
            os.environ.update(vars)

    ray_import_err = None

except ImportError as e:
    ray = None  # type: ignore
    # only capture string to avoid variable references in the traceback that can
    # prevent garbage collection in some cases
    ray_import_err = str(e)
    RayWorkerWrapper = None  # type: ignore


def ray_is_available() -> bool:
    """Report whether the Ray package was imported successfully.

    Returns:
        False when the module-level Ray import failed (``ray`` is None),
        True otherwise.
    """
    return False if ray is None else True


def assert_ray_available():
    """Raise an exception if Ray is not available.

    Raises:
        ValueError: If the module-level Ray import failed. The message
            includes the captured import error text.
    """
    if ray is None:
        # Note the trailing space after the first sentence: the two literals
        # are concatenated.
        raise ValueError(f"Failed to import Ray: {ray_import_err}. "
                         "Please install Ray with `pip install ray`.")


def _verify_bundles(placement_group: "PlacementGroup",
                    parallel_config: ParallelConfig, device_str: str):
    """Verify a given placement group has bundles located in the right place.

    There are 2 rules.
    - Warn if all tensor parallel workers cannot fit in a single node.
    - Fail if driver node is not included in a placement group.

    Args:
        placement_group: The placement group to inspect.
        parallel_config: Supplies ``tensor_parallel_size`` for the per-node
            check.
        device_str: The Ray resource key of the device (e.g. "GPU"). Used
            only in messages.

    Raises:
        RuntimeError: If no bundle of the placement group is placed on the
            driver's node.
    """
    assert ray.is_initialized(), (
        "Ray is not initialized although distributed-executor-backend is ray.")
    pg_data = placement_group_table(placement_group)
    # bundle_idx -> node_id
    bundle_to_node_ids = pg_data["bundles_to_node_id"]
    # bundle_idx -> bundle (e.g., {"GPU": 1})
    bundles = pg_data["bundles"]
    # node_id -> List of bundle (e.g., {"GPU": 1})
    node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)

    for bundle_idx, node_id in bundle_to_node_ids.items():
        node_id_to_bundle[node_id].append(bundles[bundle_idx])
    driver_node_id = ray.get_runtime_context().get_node_id()

    if driver_node_id not in node_id_to_bundle:
        raise RuntimeError(
            f"driver node id {driver_node_id} is not included in a placement "
            f"group {placement_group.id}. Node id -> bundles "
            f"{node_id_to_bundle}. "
            f"You don't have enough {device_str}s available in a current "
            "node. Check `ray status` and `ray list nodes` to see if you have "
            f"available {device_str}s in a node `{driver_node_id}` before "
            "starting a vLLM engine.")

    for node_id, node_bundles in node_id_to_bundle.items():
        if len(node_bundles) < parallel_config.tensor_parallel_size:
            logger.warning(
                "tensor_parallel_size=%d "
                "is bigger than a reserved number of %ss (%d "
                "%ss) in a node %s. Tensor parallel workers can be "
                "spread out to 2+ nodes which can degrade the performance "
                "unless you have fast interconnect across nodes, like "
                "Infiniband. To resolve this issue, make sure you have more "
                "than %d GPUs available at each node.",
                parallel_config.tensor_parallel_size, device_str,
                len(node_bundles), device_str, node_id,
                parallel_config.tensor_parallel_size)


def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
    """Wait until a placement group is ready.

    Polls the group's ``ready()`` reference with an exponentially growing
    interval. After each interval in which the group is still not ready, it
    logs an informative message. The total wait is bounded by
    ``PG_WAIT_TIMEOUT`` seconds.

    Raises:
        ValueError: If the placement group is not ready within
            ``PG_WAIT_TIMEOUT`` seconds.
    """
    # Wait until PG is ready - this will block until all
    # requested resources are available, and will time out
    # if they cannot be provisioned.
    placement_group_specs = current_placement_group.bundle_specs

    s = time.time()
    pg_ready_ref = current_placement_group.ready()
    wait_interval = 10
    while time.time() - s < PG_WAIT_TIMEOUT:
        # The interval doubles every round. Cap each wait at the time left so
        # the total wait does not overshoot PG_WAIT_TIMEOUT.
        remaining = PG_WAIT_TIMEOUT - (time.time() - s)
        ready, _ = ray.wait([pg_ready_ref],
                            timeout=max(0, min(wait_interval, remaining)))
        if ready:
            break

        # Exponential backoff for warning print.
        wait_interval *= 2
        logger.info(
            "Waiting for creating a placement group of specs for "
            "%d seconds. specs=%s. Check `ray status` and "
            "`ray list nodes` to see if you have enough resources,"
            " and make sure the IP addresses used by ray cluster"
            " are the same as VLLM_HOST_IP environment variable"
            " specified in each node if you are running on a multi-node.",
            int(time.time() - s), placement_group_specs)

    try:
        # Non-blocking final check: raises if the group is still not ready.
        ray.get(pg_ready_ref, timeout=0)
    except ray.exceptions.GetTimeoutError:
        raise ValueError(
            "Cannot provide a placement group of "
            f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
            "`ray status` and `ray list nodes` to make sure the cluster has "
            "enough resources.") from None


def _wait_until_pg_removed(current_placement_group: "PlacementGroup"):
    """Remove a placement group, then poll until it is reported gone.

    Polling stops when ``ray.util.get_current_placement_group()`` returns
    None, or after ``PG_WAIT_TIMEOUT`` seconds. Between polls it sleeps for
    an interval that doubles each round (starting at 20 seconds) and logs
    the elapsed time.
    """
    # NOTE(review): completion is judged by the *current* placement group of
    # the calling context, not by the state of the group removed here;
    # confirm this is the intended check.
    ray.util.remove_placement_group(current_placement_group)
    started_at = time.time()
    backoff = 10
    while time.time() - started_at < PG_WAIT_TIMEOUT:
        if ray.util.get_current_placement_group() is None:
            return

        # Exponential backoff for warning print.
        backoff *= 2
        logger.info(
            "Waiting for removing a placement group of specs for "
            "%d seconds.", int(time.time() - started_at))
        time.sleep(backoff)


def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
):
    """Initialize the distributed cluster with Ray.

    It will connect to the Ray cluster and create a placement group
    for the workers, which includes the specification of the resources
    for each distributed worker. The resulting placement group is stored in
    ``parallel_config.placement_group``.

    Args:
        parallel_config: The configurations for parallel execution.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Raises:
        ValueError: If Ray is unavailable, the platform has no Ray device
            key, an existing placement group is unusable, or the current
            node has no device available.
    """
    assert_ray_available()
    from vllm.platforms import current_platform

    if ray.is_initialized():
        logger.info("Ray is already initialized. Skipping Ray initialization.")
    elif current_platform.is_rocm() or current_platform.is_xpu():
        # Try to connect existing ray instance and create a new one if not found
        # NOTE(review): ray_address is not used for this first connection
        # attempt; confirm that is intended.
        try:
            ray.init("auto")
        except ConnectionError:
            logger.warning(
                "No existing RAY instance detected. "
                "A new instance will be launched with current node resources.")
            ray.init(address=ray_address,
                     num_gpus=parallel_config.world_size,
                     runtime_env=parallel_config.ray_runtime_env)
    else:
        ray.init(address=ray_address,
                 runtime_env=parallel_config.ray_runtime_env)

    device_str = current_platform.ray_device_key
    if not device_str:
        raise ValueError(
            f"current platform {current_platform.device_name} does not "
            "support ray.")

    # Create or get the placement group for worker processes
    if parallel_config.placement_group:
        current_placement_group = parallel_config.placement_group
    else:
        current_placement_group = ray.util.get_current_placement_group()

    if current_placement_group:
        logger.info("Using the existing placement group")

        # We are in a placement group
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group: every bundle may hold
        # at most one device, and there must be a device bundle per worker.
        device_bundles = 0
        for bundle in bundles:
            bundle_devices = bundle.get(device_str, 0)
            if bundle_devices > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 "
                    f"{device_str}.")
            if bundle_devices:
                device_bundles += 1
        if parallel_config.world_size > device_bundles:
            raise ValueError(
                f"The number of required {device_str}s exceeds the total "
                f"number of available {device_str}s in the placement group. "
                f"Required number of devices: {parallel_config.world_size}. "
                f"Total number of devices: {device_bundles}.")
    else:
        logger.info("No current placement group found. "
                    "Creating a new placement group.")
        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
        # Log a warning message and delay resource allocation failure
        # response. Avoid immediate rejection so that a user-initiated
        # placement group can be created and the cluster can become ready.
        if parallel_config.world_size > num_devices_in_cluster:
            logger.warning(
                "The number of required %ss (%d) exceeds the total "
                "number of %ss in the Ray cluster (%d).", device_str,
                parallel_config.world_size, device_str,
                num_devices_in_cluster)
        # Create a new placement group: one bundle per worker, one device
        # per bundle.
        placement_group_specs: List[Dict[str, float]] = ([{
            device_str: 1.0
        } for _ in range(parallel_config.world_size)])

        # vLLM engine is also a worker to execute model with an accelerator,
        # so it requires to have the device in a current node. Check if
        # the current node has at least one device.
        current_ip = get_ip()
        current_node_id = ray.get_runtime_context().get_node_id()
        current_node_resource = available_resources_per_node()[current_node_id]
        if current_node_resource.get(device_str, 0) < 1:
            raise ValueError(
                f"Current node has no {device_str} available. "
                f"{current_node_resource=}. vLLM engine cannot start without "
                f"{device_str}. Make sure you have at least 1 {device_str} "
                f"available in a node {current_node_id=} {current_ip=}.")
        # This way, at least one bundle is required to be created in the
        # current node.
        placement_group_specs[0][f"node:{current_ip}"] = 0.001

        # By default, Ray packs resources as much as possible.
        current_placement_group = ray.util.placement_group(
            placement_group_specs, strategy="PACK")
        _wait_until_pg_ready(current_placement_group)

    assert current_placement_group is not None
    _verify_bundles(current_placement_group, parallel_config, device_str)
    # Set the placement group in the parallel config
    parallel_config.placement_group = current_placement_group


def get_num_tpu_nodes() -> int:
    """Return how many TPU nodes the Ray cluster contains.

    Computed as the cluster-wide ``TPU`` resource count divided by the number
    of TPUs on the current node. Asserts that the division is exact.
    """
    from ray._private.accelerators import TPUAcceleratorManager
    tpu_total = int(ray.cluster_resources()["TPU"])
    per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
    node_count, leftover = divmod(tpu_total, per_node)
    assert leftover == 0
    return node_count


def get_num_nodes_in_placement_group() -> int:
    """Count distinct nodes hosting bundles of the current placement group.

    Returns:
        The number of distinct node ids in the current placement group's
        ``bundles_to_node_id`` mapping, or 0 when there is no current
        placement group (or it is absent from the placement group table).
    """
    pg_table = ray.util.placement_group_table()
    current_pg = ray.util.get_current_placement_group()
    if not current_pg:
        return 0

    target_key = current_pg.id.hex()
    distinct_nodes = {
        node
        for key, info in pg_table.items() if key == target_key
        for node in info["bundles_to_node_id"].values()
    }
    return len(distinct_nodes)