"vscode:/vscode.git/clone" did not exist on "fadb8d5c2df1c24d891aeccfb0b11de6e03e9f27"
ray_utils.py 24.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
6
import time
from collections import defaultdict
7
from concurrent.futures import Future
8
from typing import TYPE_CHECKING, Union
9

10
import vllm.platforms
11
from vllm.config import ParallelConfig
12
from vllm.distributed import get_pp_group
13
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
14
from vllm.logger import init_logger
15
from vllm.platforms import current_platform
16
from vllm.sequence import IntermediateTensors
17
from vllm.utils.network_utils import get_ip
18
from vllm.v1.outputs import AsyncModelRunnerOutput
19
from vllm.v1.serial_utils import run_method
20
from vllm.v1.worker.worker_base import WorkerWrapperBase
21

22
if TYPE_CHECKING:
23
    from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
24
25
    from vllm.v1.outputs import ModelRunnerOutput

26
logger = init_logger(__name__)
27
PG_WAIT_TIMEOUT = 1800
28

29
30
31
32
33
34
35
36
37
38
39
# Environment variables whose values are specific to a single worker.
# They must NOT be propagated from the driver to Ray workers; each worker
# gets its own values after GPU discovery.
WORKER_SPECIFIC_ENV_VARS: set[str] = {
    "CUDA_VISIBLE_DEVICES",
    "HIP_VISIBLE_DEVICES",
    "LOCAL_RANK",
    "ROCR_VISIBLE_DEVICES",
    "VLLM_HOST_IP",
    "VLLM_HOST_PORT",
}

40
41
try:
    import ray
    from ray.util import placement_group_table
    from ray.util.placement_group import PlacementGroup

    try:
        from ray._private.state import available_resources_per_node
    except ImportError:
        # Ray 2.9.x doesn't expose `available_resources_per_node`
        from ray._private.state import state as _state

        available_resources_per_node = _state._available_resources_per_node

    class RayWorkerWrapper(WorkerWrapperBase):
        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # The compiled DAG runs its main execution on a different
            # thread, and that thread has to call set_device itself.
            # This flag records whether set_device has already been
            # called on that thread.
            self.compiled_dag_cuda_device_set = False

        rpc_rank: int

        def adjust_rank(self, rank_mapping: dict[int, int]) -> None:
            """
            Adjust the rpc_rank based on the given mapping.
            It is only used during the initialization of the executor,
            to adjust the rpc_rank of workers after we create all workers.
            """
            if self.rpc_rank in rank_mapping:
                self.rpc_rank = rank_mapping[self.rpc_rank]

        def execute_method(self, method: str | bytes, *args, **kwargs):
            """Run ``method`` on this wrapper through ``run_method``.

            Any exception is logged (with traceback) and then re-raised
            unchanged.
            """
            try:
                return run_method(self, method, args, kwargs)
            except Exception:
                # if the driver worker also execute methods,
                # exceptions in the rest worker may cause deadlock in rpc
                # see https://github.com/vllm-project/vllm/issues/3455
                msg = (
                    f"Error executing method {method!r}. "
                    "This might cause deadlock in distributed execution."
                )
                logger.exception(msg)
                # Bare raise keeps the original traceback intact.
                raise

        def get_node_ip(self) -> str:
            """Return the IP address reported by ``get_ip()``."""
            return get_ip()

        def get_node_and_gpu_ids(self) -> tuple[str, list[int]]:
            """Return the Ray node ID and the accelerator IDs for this worker.

            The accelerator IDs are looked up in the Ray runtime context
            under the current platform's ``ray_device_key``.

            Raises:
                RuntimeError: If the current platform has no
                    ``ray_device_key``.
            """
            node_id = ray.get_runtime_context().get_node_id()
            device_key = vllm.platforms.current_platform.ray_device_key
            if not device_key:
                # Exceptions do not apply %-formatting to their arguments,
                # so the message has to be built before raising.
                device_name = vllm.platforms.current_platform.device_name
                raise RuntimeError(
                    f"current platform {device_name} does not support ray."
                )
            gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key]
            return node_id, gpu_ids

        def setup_device_if_necessary(self):
            """Set the platform device once for the compiled-graph thread.

            Does nothing after the first successful call. On TPU the
            device is not set, but the flag is still recorded.
            """
            # TODO(swang): This is needed right now because Ray CG executes
            # on a background thread, so we need to reset torch's current
            # device.
            # We can remove this API after it is fixed in compiled graph.
            assert self.worker is not None, "Worker is not initialized"
            if self.compiled_dag_cuda_device_set:
                return

            # Setting the device is not needed on TPU.
            if not current_platform.is_tpu():
                assert self.worker.device is not None
                current_platform.set_device(self.worker.device)

            self.compiled_dag_cuda_device_set = True

        def execute_model_ray(
            self,
            execute_model_input: tuple["SchedulerOutput", "GrammarOutput"]
            | tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
        ) -> Union[
            "ModelRunnerOutput",
            tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
        ]:
            """Execute the model for one Ray Compiled Graph step.

            Args:
                execute_model_input: Either ``(scheduler_output,
                    grammar_output)`` or ``(scheduler_output, grammar_output,
                    intermediate_tensors)``.

            Returns:
                ``(scheduler_output, grammar_output, output)`` when the model
                runner produced intermediate tensors;
                ``(scheduler_output, grammar_output, None)`` when this worker
                is not the last pipeline-parallel rank; otherwise the model
                runner's output (sampling tokens first if the runner
                returned ``None``).
            """
            # This method is used by Ray Compiled Graph to execute the model,
            # and it needs a special logic of self.setup_device_if_necessary()
            self.setup_device_if_necessary()
            assert self.worker is not None, "Worker is not initialized"
            if len(execute_model_input) == 3:
                scheduler_output, grammar_output, intermediate_tensors = (
                    execute_model_input
                )
            else:
                scheduler_output, grammar_output = execute_model_input
                intermediate_tensors = None
            assert self.worker.model_runner is not None
            output = self.worker.model_runner.execute_model(
                scheduler_output, intermediate_tensors
            )
            if self._is_intermediate_tensors(output):
                if (
                    self.worker.model_runner.supports_mm_inputs
                    and get_pp_group().is_first_rank
                ):
                    # Strip mm_features before Ray forwards it to the next PP Stage.
                    # PP Stage>0 only needs the intermediate tensors,
                    # not preprocessed multimodal data.

                    # scheduled_new_reqs is a required field of SchedulerOutput,
                    # so accessing it directly will raise AttributeError if missing.
                    for req in scheduler_output.scheduled_new_reqs:
                        req.mm_features = []
                return scheduler_output, grammar_output, output

            if isinstance(output, AsyncModelRunnerOutput):
                output = output.get_output()
            if not self._is_last_rank():
                # Case where there are no scheduled requests
                # but may still be finished requests.
                assert not output or not output.req_ids
                output = scheduler_output, grammar_output, None
            elif output is None:
                output = self.worker.model_runner.sample_tokens(grammar_output)
                # Ensure outputs crossing Ray compiled DAG are serializable.
                # AsyncModelRunnerOutput holds CUDA events and cannot be
                # pickled.
                if isinstance(output, AsyncModelRunnerOutput):
                    output = output.get_output()
            return output

        def override_env_vars(self, vars: dict[str, str]):
            """Update this process's ``os.environ`` with ``vars``."""
            os.environ.update(vars)

        def _is_intermediate_tensors(self, output) -> bool:
            """Return True if ``output`` is an ``IntermediateTensors``."""
            return isinstance(output, IntermediateTensors)

        def _is_last_rank(self) -> bool:
            """Return True if this worker is the last pipeline-parallel rank."""
            return get_pp_group().is_last_rank

    ray_import_err = None

except ImportError as e:
    ray = None  # type: ignore
    # only capture string to avoid variable references in the traceback that can
    # prevent garbage collection in some cases
    ray_import_err = str(e)
    RayWorkerWrapper = None  # type: ignore
191
192


193
194
195
196
197
198
199
200
201
class FutureWrapper(Future):
    """A wrapper around Ray output reference(s) to meet the interface
    of .execute_model(): the top level (core busy loop) expects a
    .result() API that blocks and returns a single output.

    If an aggregator is provided, the outputs fetched from Ray are passed to
    ``aggregator.aggregate(outputs, output_rank=0)`` on each result() call.
    Without an aggregator, whatever ``ray.get`` returns for ``ref_or_refs``
    is returned unchanged.
    """

    def __init__(self, ref_or_refs, aggregator: "KVOutputAggregator | None" = None):
        """Store the Ray reference(s) and the optional aggregator.

        Args:
            ref_or_refs: The object passed to ``ray.get`` in result().
            aggregator: Optional aggregator applied to the fetched outputs.
        """
        super().__init__()
        self.ref_or_refs = ref_or_refs
        self.aggregator = aggregator

    def result(self, timeout=None):
        """Block on ``ray.get`` and return the (optionally aggregated) output.

        Args:
            timeout: Forwarded to ``ray.get``.
        """
        outputs = ray.get(self.ref_or_refs, timeout=timeout)
        if self.aggregator is None:
            return outputs

        return self.aggregator.aggregate(outputs, output_rank=0)


215
216
217
218
219
220
221
222
def ray_is_available() -> bool:
    """Report whether the module-level ``ray`` import succeeded."""
    ray_missing = ray is None
    return not ray_missing


def assert_ray_available():
    """Raise an exception if Ray is not available.

    Raises:
        ValueError: If the module-level ``ray`` import failed; the message
            includes the captured import error text.
    """
    if ray is None:
        # Note the space after the first sentence: the two fragments are
        # concatenated into a single message.
        raise ValueError(
            f"Failed to import Ray: {ray_import_err}. "
            "Please install Ray with `pip install ray`."
        )
227
228


229
def _verify_bundles(
    placement_group: "PlacementGroup",
    parallel_config: ParallelConfig,
    device_str: str,
    require_gpu_on_driver: bool = True,
):
    """Verify a given placement group has bundles located in the right place.

    There are 2 rules.
    - Warn if all tensor parallel workers cannot fit in a single node.
    - Fail if driver node is not included in a placement group
      (only when require_gpu_on_driver is True).

    Args:
        placement_group: The placement group to inspect.
        parallel_config: Provides ``tensor_parallel_size`` for the warning.
        device_str: Device name used in the warning text.
        require_gpu_on_driver: Whether the driver node must own a bundle.

    Raises:
        RuntimeError: If ``require_gpu_on_driver`` is True and the driver
            node holds no bundle of the placement group.
    """
    assert ray.is_initialized(), (
        "Ray is not initialized although distributed-executor-backend is ray."
    )
    pg_data = placement_group_table(placement_group)
    # bundle_idx -> node_id
    bundle_to_node_ids = pg_data["bundles_to_node_id"]
    # bundle_idx -> bundle (e.g., {"GPU": 1})
    bundles = pg_data["bundles"]
    # node_id -> List of bundle (e.g., {"GPU": 1})
    node_id_to_bundle: dict[str, list[dict[str, float]]] = defaultdict(list)

    for bundle_idx, node_id in bundle_to_node_ids.items():
        node_id_to_bundle[node_id].append(bundles[bundle_idx])
    driver_node_id = ray.get_runtime_context().get_node_id()

    if require_gpu_on_driver and driver_node_id not in node_id_to_bundle:
        raise RuntimeError(
            f"driver node id {driver_node_id} is not included in a placement "
            f"group {placement_group.id}. Node id -> bundles "
            f"{node_id_to_bundle}. "
            "You don't have enough GPUs available in a current node. Check "
            "`ray status` and `ray list nodes` to see if you have available "
            f"GPUs in a node `{driver_node_id}` before starting an vLLM engine."
        )

    # A distinct loop name avoids shadowing the `bundles` mapping above.
    for node_id, node_bundles in node_id_to_bundle.items():
        if len(node_bundles) < parallel_config.tensor_parallel_size:
            logger.warning(
                "tensor_parallel_size=%d "
                "is bigger than a reserved number of %ss (%d "
                "%ss) in a node %s. Tensor parallel workers can be "
                "spread out to 2+ nodes which can degrade the performance "
                "unless you have fast interconnect across nodes, like "
                "Infiniband. To resolve this issue, make sure you have more "
                "than %d GPUs available at each node.",
                parallel_config.tensor_parallel_size,
                device_str,
                len(node_bundles),
                device_str,
                node_id,
                parallel_config.tensor_parallel_size,
            )
284
285


286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
def build_actor_name(
    instance_id: str,
    rank: int,
    tp_size: int,
    pp_size: int,
    pcp_size: int,
) -> str:
    """Compose a descriptive Ray actor name for dashboard visibility.

    The name always starts with ``vllm_Worker_<instance_id>``. A ``TP``,
    ``PP`` and/or ``PCP`` suffix derived from ``rank`` is appended for each
    parallel dimension whose size is greater than 1.
    """
    segments = [f"vllm_Worker_{instance_id}"]
    if tp_size > 1:
        tp_index = rank % tp_size
        segments.append(f"TP{tp_index}")
    if pp_size > 1:
        pp_index = (rank // tp_size) % pp_size
        segments.append(f"PP{pp_index}")
    if pcp_size > 1:
        pcp_index = rank // (tp_size * pp_size)
        segments.append(f"PCP{pcp_index}")
    return "_".join(segments)


def get_bundles_for_indices(
    placement_group: "PlacementGroup",
    bundle_indices: list[int],
    world_size: int,
) -> list[tuple[int, str, str]]:
    """
    Pair each explicitly requested bundle index (as given through
    VLLM_RAY_BUNDLE_INDICES) with the node ID and node IP of that bundle.

    The result keeps the order of ``bundle_indices``. The indices must be
    unique and there must be exactly ``world_size`` of them.
    """
    assert len(bundle_indices) == world_size, (
        "VLLM_RAY_BUNDLE_INDICES must have the same size"
        f" as the world size, but got {bundle_indices=} "
        f"and {world_size=}"
    )
    distinct_count = len(set(bundle_indices))
    assert distinct_count == len(bundle_indices), (
        "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
        f" but got {bundle_indices=}"
    )

    node_of_bundle = placement_group_table(placement_group)["bundles_to_node_id"]

    # Only alive nodes are considered when resolving node IPs.
    ip_of_node = {}
    for node in ray.nodes():
        if node["Alive"]:
            alive_node_id = node["NodeID"]
            ip_of_node[alive_node_id] = node["NodeManagerAddress"]

    resolved = []
    for bid in bundle_indices:
        owner = node_of_bundle[bid]
        resolved.append((bid, owner, ip_of_node[owner]))
    return resolved


def get_bundles_sorted_by_node(
    placement_group: "PlacementGroup",
) -> list[tuple[int, str, str]]:
    """
    Return ``(bundle_index, node_id, node_ip)`` for every bundle that
    requests the current platform's Ray device, sorted driver-first.

    Bundles on the driver node come first; the remaining bundles are
    grouped by node ID. Within one node the bundle order is unchanged.

    This utility has to be invoked from the driver node.

    Example (node IPs omitted): 3-node cluster, driver on node-A, PG
    bundles spread across nodes:

      Input: [
          (0, node-C),
          (1, node-A),
          (2, node-B),
          (3, node-C),
          (4, node-A),
          (5, node-B),
      ]
      Output: [
          (1, node-A),
          (4, node-A),
          (2, node-B),
          (5, node-B),
          (0, node-C),
          (3, node-C),
      ]
    """
    node_of_bundle = placement_group_table(placement_group)["bundles_to_node_id"]

    device_key = current_platform.ray_device_key
    if not device_key:
        raise ValueError(
            f"current platform {current_platform.device_name} does not support ray."
        )

    # Only alive nodes are considered when resolving node IPs.
    ip_of_node = {}
    for node in ray.nodes():
        if node["Alive"]:
            alive_node_id = node["NodeID"]
            ip_of_node[alive_node_id] = node["NodeManagerAddress"]

    specs = placement_group.bundle_specs
    assert specs is not None

    device_bundles: list[tuple[int, str, str]] = []
    for idx, spec in enumerate(specs):
        # Skip bundles that do not request the platform device.
        if not spec.get(device_key):
            continue
        owner = node_of_bundle.get(idx)
        device_bundles.append((idx, owner, ip_of_node[owner]))

    driver_node = ray.get_runtime_context().get_node_id()

    # Driver node sorts first, then by node ID; the sort is stable.
    return sorted(
        device_bundles,
        key=lambda entry: (0 if entry[1] == driver_node else 1, entry[1]),
    )


395
396
397
398
399
400
401
402
def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
    """Wait until a placement group is ready.

    It prints the informative log messages if the placement group is
    not created within time.

    Args:
        current_placement_group: The placement group to wait for.

    Raises:
        ValueError: If the placement group is still not ready once
            PG_WAIT_TIMEOUT seconds have elapsed.
    """
    # Wait until PG is ready - this will block until all
    # requested resources are available, and will time out
    # if they cannot be provisioned.
    placement_group_specs = current_placement_group.bundle_specs

    s = time.time()
    pg_ready_ref = current_placement_group.ready()
    wait_interval = 10
    while True:
        remaining = PG_WAIT_TIMEOUT - (time.time() - s)
        if remaining <= 0:
            break
        # Cap each wait at the time left so that the doubling interval
        # cannot push the total wait well beyond PG_WAIT_TIMEOUT.
        ready, _ = ray.wait([pg_ready_ref], timeout=min(wait_interval, remaining))
        if len(ready) > 0:
            break

        # Exponential backoff for warning print.
        wait_interval *= 2
        logger.info(
            "Waiting for creating a placement group of specs for "
            "%d seconds. specs=%s. Check `ray status` and "
            "`ray list nodes` to see if you have enough resources,"
            " and make sure the IP addresses used by ray cluster"
            " are the same as VLLM_HOST_IP environment variable"
            " specified in each node if you are running on a multi-node.",
            int(time.time() - s),
            placement_group_specs,
        )

    try:
        # Non-blocking check: raises GetTimeoutError if still not ready.
        ray.get(pg_ready_ref, timeout=0)
    except ray.exceptions.GetTimeoutError:
        # Provide more helpful error message when GPU count is exceeded
        total_gpu_required = sum(spec.get("GPU", 0) for spec in placement_group_specs)
        # If more than one GPU is required for the placement group, provide a
        # more specific error message.
        # We use >1 here because multi-GPU (tensor parallel) jobs are more
        # likely to fail due to insufficient cluster resources, and users may
        # need to adjust tensor_parallel_size to fit available GPUs.
        if total_gpu_required > 1:
            raise ValueError(
                f"Cannot provide a placement group requiring "
                f"{total_gpu_required} GPUs "
                f"(placement_group_specs={placement_group_specs}) within "
                f"{PG_WAIT_TIMEOUT} seconds.\n"
                f"Tensor parallel size may exceed available GPUs in your "
                f"cluster. Check resources with `ray status` and "
                f"`ray list nodes`.\n"
                f"If running on K8s with limited GPUs, consider reducing "
                f"--tensor-parallel-size to match available GPU resources."
            ) from None
        else:
            raise ValueError(
                "Cannot provide a placement group of "
                f"{placement_group_specs=} within "
                f"{PG_WAIT_TIMEOUT} seconds. See "
                "`ray status` and `ray list nodes` to make sure the cluster "
                "has enough resources."
            ) from None
458
459
460
461
462
463
464
465
466
467
468
469
470
471


def _wait_until_pg_removed(current_placement_group: "PlacementGroup"):
    """Remove a placement group and wait for the removal to take effect.

    After requesting removal, polls until
    ``ray.util.get_current_placement_group()`` returns None or
    PG_WAIT_TIMEOUT seconds have elapsed, sleeping for a doubling interval
    between polls. Returns silently on timeout.

    Args:
        current_placement_group: The placement group to remove.
    """
    ray.util.remove_placement_group(current_placement_group)
    s = time.time()
    wait_interval = 10
    while time.time() - s < PG_WAIT_TIMEOUT:
        # NOTE(review): this polls the *current* placement group rather than
        # the state of `current_placement_group` itself - confirm intended.
        pg = ray.util.get_current_placement_group()
        if pg is None:
            break

        # Exponential backoff for warning print.
        wait_interval *= 2
        logger.info(
            "Waiting for removing a placement group of specs for %d seconds.",
            int(time.time() - s),
        )
        # Never sleep past the overall deadline.
        remaining = PG_WAIT_TIMEOUT - (time.time() - s)
        time.sleep(max(0.0, min(wait_interval, remaining)))


478
def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: str | None = None,
    require_gpu_on_driver: bool = True,
):
    """Initialize the distributed cluster with Ray.

    it will connect to the Ray cluster and create a placement group
    for the workers, which includes the specification of the resources
    for each distributed worker.

    Args:
        parallel_config: The configurations for parallel execution.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.
        require_gpu_on_driver: If True (default), require at least one GPU
            on the current (driver) node and pin the first PG bundle to it.
            Set to False for executors like RayExecutorV2 where all GPU work
            is delegated to remote Ray actors.

    Raises:
        ValueError: If Ray is not importable, the current platform has no
            Ray device key, an existing placement group is unsuitable, or
            (when ``require_gpu_on_driver`` is True) the current node has
            no device available.
    """
    assert_ray_available()
    from vllm.platforms import current_platform

    # Disable Ray usage stats collection
    if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
        os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

    # Prevalidate GPU requirements before Ray processing
    if current_platform.is_cuda() and parallel_config.world_size > 1:
        available_gpus = current_platform.device_count()
        if parallel_config.world_size > available_gpus:
            logger.warning(
                "Tensor parallel size (%d) exceeds available GPUs (%d). "
                "This may result in Ray placement group allocation failures. "
                "Consider reducing tensor_parallel_size to %d or less, "
                "or ensure your Ray cluster has %d GPUs available.",
                parallel_config.world_size,
                available_gpus,
                available_gpus,
                parallel_config.world_size,
            )

    if ray.is_initialized():
        logger.info("Ray is already initialized. Skipping Ray initialization.")
    elif current_platform.is_rocm() or current_platform.is_xpu():
        # Try to connect existing ray instance and create a new one if not found
        try:
            ray.init("auto")
        except ConnectionError:
            logger.warning(
                "No existing RAY instance detected. "
                "A new instance will be launched with current node resources."
            )
            ray.init(
                address=ray_address,
                num_gpus=parallel_config.world_size,
                runtime_env=parallel_config.ray_runtime_env,
            )
    else:
        ray.init(address=ray_address, runtime_env=parallel_config.ray_runtime_env)

    device_str = current_platform.ray_device_key
    if not device_str:
        raise ValueError(
            f"current platform {current_platform.device_name} does not support ray."
        )

    # Create or get the placement group for worker processes
    if parallel_config.placement_group:
        current_placement_group = parallel_config.placement_group
    else:
        current_placement_group = ray.util.get_current_placement_group()

    if current_placement_group:
        logger.info("Using the existing placement group")

        # We are in a placement group
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group.
        device_bundles = 0
        for bundle in bundles:
            bundle_devices = bundle.get(device_str, 0)
            if bundle_devices > 1:
                raise ValueError(
                    f"Placement group bundle cannot have more than 1 {device_str}."
                )
            if bundle_devices:
                device_bundles += 1
        if parallel_config.world_size > device_bundles:
            raise ValueError(
                f"The number of required {device_str}s exceeds the total "
                f"number of available {device_str}s in the placement group. "
                f"Required number of devices: {parallel_config.world_size}. "
                f"Total number of devices: {device_bundles}."
            )
    else:
        logger.info("No current placement group found. Creating a new placement group.")
        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
        # Log a warning message and delay resource allocation failure response.
        # Avoid immediate rejection to allow user-initiated placement group
        # created and wait cluster to be ready
        if parallel_config.world_size > num_devices_in_cluster:
            logger.warning(
                "The number of required %ss exceeds the total "
                "number of available %ss in the placement group.",
                device_str,
                device_str,
            )
        # Create a new placement group
        placement_group_specs: list[dict[str, float]] = [
            {device_str: 1.0} for _ in range(parallel_config.world_size)
        ]

        # TODO (jeffreywang): require_gpu_on_driver should be always False
        # after deprecating RayDistributedExecutor.
        if require_gpu_on_driver:
            # vLLM engine is also a worker to execute model with an
            # accelerator, so it requires to have the device in a current
            # node. Check if the current node has at least one device.
            # The lookup lives inside this branch because its results are
            # only needed when the driver must own a device.
            current_ip = get_ip()
            current_node_id = ray.get_runtime_context().get_node_id()
            current_node_resource = available_resources_per_node()[current_node_id]
            if current_node_resource.get(device_str, 0) < 1:
                raise ValueError(
                    f"Current node has no {device_str} available. "
                    f"{current_node_resource=}. vLLM engine cannot start "
                    f"without {device_str}. Make sure you have at least 1 "
                    f"{device_str} available in a node "
                    f"{current_node_id=} {current_ip=}."
                )
            # This way, at least bundle is required to be created in a
            # current node.
            placement_group_specs[0][f"node:{current_ip}"] = 0.001

        # By default, Ray packs resources as much as possible.
        current_placement_group = ray.util.placement_group(
            placement_group_specs, strategy="PACK"
        )
        _wait_until_pg_ready(current_placement_group)

    assert current_placement_group is not None
    _verify_bundles(
        current_placement_group, parallel_config, device_str, require_gpu_on_driver
    )
    # Set the placement group in the parallel config
    parallel_config.placement_group = current_placement_group
624
625
626
627


def get_num_tpu_nodes() -> int:
    """Return the number of TPU nodes in the Ray cluster.

    Computed as the cluster's total ``"TPU"`` resource divided by the number
    of accelerators reported for the current node.
    NOTE(review): assumes every node has the same accelerator count as the
    current one - confirm.
    """
    from ray._private.accelerators import TPUAcceleratorManager

    cluster_resources = ray.cluster_resources()
    total_tpus = int(cluster_resources["TPU"])
    tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
    # The total is expected to be an exact multiple of the per-node count.
    assert total_tpus % tpus_per_node == 0
    return total_tpus // tpus_per_node


def get_num_nodes_in_placement_group() -> int:
    """Count the distinct nodes that host bundles of the current placement group.

    Returns 0 when there is no current placement group, or when it does not
    appear in Ray's placement group table.
    """
    pg_table = ray.util.placement_group_table()
    current_pg = ray.util.get_current_placement_group()
    if not current_pg:
        return 0

    # Table keys are placement group IDs in hex form.
    entry = pg_table.get(current_pg.id.hex())
    if entry is None:
        return 0
    return len(set(entry["bundles_to_node_id"].values()))