ray_utils.py 4.37 KB
Newer Older
1
import pickle
2
from typing import List, Optional, Tuple
3
4

from vllm.config import ParallelConfig
5
from vllm.logger import init_logger
6
7
from vllm.utils import get_ip, is_hip
from vllm.worker.worker_base import WorkerWrapperBase
8
9

logger = init_logger(__name__)
10
11
12

try:
    import ray

    class RayWorkerWrapper(WorkerWrapperBase):
        """Ray wrapper for vllm.worker.Worker.

        The wrapper allows the Worker to be lazily initialized after Ray
        sets CUDA_VISIBLE_DEVICES.
        """

        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # The compiled DAG runs its main execution in a different
            # thread, and that thread calls cuda.set_device. This flag
            # records whether set_device has already been called there.
            self.compiled_dag_cuda_device_set = False

        def get_node_ip(self) -> str:
            """Return whatever ``get_ip()`` reports for this process."""
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            """Return the Ray node id together with Ray's GPU ids.

            Returns:
                A ``(node_id, gpu_ids)`` tuple, where ``node_id`` comes from
                the Ray runtime context and ``gpu_ids`` from
                ``ray.get_gpu_ids()``.
            """
            runtime_ctx = ray.get_runtime_context()
            return runtime_ctx.get_node_id(), ray.get_gpu_ids()

        def execute_model_compiled_dag_remote(self, ignored):
            """Used only when compiled DAG is enabled.

            Args:
                ignored: Not used by this method.

            Returns:
                The pickled result of ``self.worker.execute_model()``.
            """
            import torch

            # Bind the CUDA device once; later calls skip this step.
            if not self.compiled_dag_cuda_device_set:
                torch.cuda.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

            return pickle.dumps(self.worker.execute_model())

except ImportError as import_err:
    logger.warning(
        "Failed to import Ray with %r. For distributed inference, "
        "please install Ray with `pip install ray`.", import_err)
    # Placeholders so the module-level names always exist; callers can
    # test them against None to detect that Ray is unavailable.
    ray = None  # type: ignore
    RayWorkerWrapper = None  # type: ignore
51
52


53
def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
):
    """Initialize the distributed cluster with Ray.

    Connects to the Ray cluster and makes sure ``parallel_config`` carries
    a placement group describing the resources for each distributed worker.
    If the config already has a placement group, nothing else is done.
    Otherwise the placement group the caller is currently running in is
    validated and reused, or a new one with one GPU per worker is created.

    Args:
        parallel_config: The configurations for parallel execution. Its
            ``placement_group`` attribute is set by this function when it
            is not already populated.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Raises:
        ImportError: If Ray could not be imported.
        ValueError: If a bundle of the current placement group has more
            than 1 GPU, or if ``parallel_config.world_size`` exceeds the
            GPUs available in the placement group / cluster.
    """
    if ray is None:
        raise ImportError(
            "Ray is not installed. Please install Ray to use distributed "
            "serving.")

    # Connect to a ray cluster. On HIP the GPU count is passed explicitly.
    init_kwargs = {"address": ray_address, "ignore_reinit_error": True}
    if is_hip():
        init_kwargs["num_gpus"] = parallel_config.world_size
    ray.init(**init_kwargs)

    # Placement group is already set: nothing more to do.
    if parallel_config.placement_group:
        return

    placement_group = ray.util.get_current_placement_group()
    if not placement_group:
        # Not inside a placement group: check the cluster has enough GPUs
        # and create a new group with one single-GPU bundle per worker.
        cluster_gpus = ray.cluster_resources().get("GPU", 0)
        if parallel_config.world_size > cluster_gpus:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the cluster.")
        bundle_specs = [{"GPU": 1} for _ in range(parallel_config.world_size)]
        placement_group = ray.util.placement_group(bundle_specs)
        # Wait until PG is ready - this will block until all
        # requested resources are available, and will timeout
        # if they cannot be provisioned.
        ray.get(placement_group.ready(), timeout=1800)
    else:
        # We are in a placement group: verify that we can use it.
        gpus_per_bundle = [
            spec.get("GPU", 0) for spec in placement_group.bundle_specs
        ]
        if any(count > 1 for count in gpus_per_bundle):
            raise ValueError(
                "Placement group bundle cannot have more than 1 GPU.")
        # Only bundles that actually reserve a GPU can host a worker.
        usable_bundles = sum(1 for count in gpus_per_bundle if count)
        if parallel_config.world_size > usable_bundles:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the placement group.")

    # Set the placement group in the parallel config
    parallel_config.placement_group = placement_group