ray_utils.py 5.34 KB
Newer Older
1
2
import pickle

3
from typing import Optional, List, Tuple
4
5

from vllm.config import ParallelConfig
6
from vllm.logger import init_logger
7
from vllm.utils import is_hip, set_cuda_visible_devices, get_ip
8
9

# Module-level logger for this file, obtained from vllm's init_logger helper.
logger = init_logger(__name__)
10
11
12

try:
    import ray

    class RayWorkerVllm:
        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

        def __init__(self, init_cached_hf_modules=False) -> None:
            """Create an empty wrapper; the real worker is built later.

            Args:
                init_cached_hf_modules: If True, call transformers'
                    ``init_hf_modules()`` in this process first.
            """
            if init_cached_hf_modules:
                from transformers.dynamic_module_utils import init_hf_modules
                init_hf_modules()
            self.worker = None
            # The compiled DAG runs its main execution in a different
            # thread, and that thread has to call cuda.set_device itself.
            # This flag records whether set_device has already been called
            # on that thread.
            self.compiled_dag_cuda_device_set = False

        def init_worker(self, worker_init_fn):
            """Build the wrapped worker by calling ``worker_init_fn()``."""
            self.worker = worker_init_fn()

        def __getattr__(self, name):
            """Delegate unknown attribute lookups to the wrapped worker."""
            # __getattr__ only runs after normal lookup fails. If ``worker``
            # itself is missing from the instance (e.g. an instance created
            # without __init__, as copy/pickle reconstruction does), looking
            # up ``self.worker`` below would re-enter this method forever.
            # Fail with a plain AttributeError instead.
            if name == "worker":
                raise AttributeError(
                    f"{type(self).__name__!r} object has no attribute "
                    f"{name!r}")
            return getattr(self.worker, name)

        def execute_method(self, method, *args, **kwargs):
            """Call the method named ``method`` with the given arguments.

            The name is resolved on this wrapper first and then, via
            ``__getattr__``, on the wrapped worker. Any exception is logged
            before being re-raised unchanged.
            """
            try:
                executor = getattr(self, method)
                return executor(*args, **kwargs)
            except Exception:
                # exceptions in ray worker may cause deadlock
                # see https://github.com/vllm-project/vllm/issues/3455
                # print the error and inform the user to solve the error
                msg = (f"Error executing method {method}. "
                       "This might cause deadlock in distributed execution.")
                logger.exception(msg)
                raise

        def get_node_ip(self) -> str:
            """Return the IP address reported by ``get_ip()`` on this node."""
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            """Return this actor's Ray node ID and its Ray-assigned GPU IDs."""
            node_id = ray.get_runtime_context().get_node_id()
            gpu_ids = ray.get_gpu_ids()
            return node_id, gpu_ids

        def set_cuda_visible_devices(self, device_ids) -> None:
            """Forward ``device_ids`` to the module-level helper of the same
            name."""
            set_cuda_visible_devices(device_ids)

        def execute_model_compiled_dag_remote(self, ignored):
            """Run one ``execute_model()`` step and return the pickled output.

            Used only when compiled DAG is enabled. ``ignored`` is not used.
            The CUDA device is set on the first call only, tracked by
            ``compiled_dag_cuda_device_set``.
            """
            import torch
            if not self.compiled_dag_cuda_device_set:
                torch.cuda.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

            output = self.worker.execute_model()
            output = pickle.dumps(output)
            return output

except ImportError as e:
    logger.warning(
        "Failed to import Ray with %r. For distributed inference, please "
        "install Ray with `pip install ray`.", e)
    ray = None
    RayWorkerVllm = None
76
77


78
def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
):
    """Initialize the distributed cluster with Ray.

    It will connect to the Ray cluster and create a placement group
    for the workers, which includes the specification of the resources
    for each distributed worker. The placement group is stored in
    ``parallel_config.placement_group``. If that attribute is already set,
    the function returns right after connecting.

    Args:
        parallel_config: The configurations for parallel execution.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Raises:
        ImportError: If Ray is not installed.
        ValueError: If a bundle of the enclosing placement group has more
            than one GPU, or if the placement group / cluster does not
            provide ``parallel_config.world_size`` GPUs.
    """
    if ray is None:
        raise ImportError(
            "Ray is not installed. Please install Ray to use distributed "
            "serving.")

    # Connect to a ray cluster.
    if is_hip():
        ray.init(address=ray_address,
                 ignore_reinit_error=True,
                 num_gpus=parallel_config.world_size)
    else:
        ray.init(address=ray_address, ignore_reinit_error=True)

    if parallel_config.placement_group:
        # Placement group is already set.
        return

    # Create placement group for worker processes
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group; verify that we can use it.
        _verify_placement_group_gpus(current_placement_group,
                                     parallel_config.world_size)
    else:
        current_placement_group = _create_gpu_placement_group(
            parallel_config.world_size)

    # Set the placement group in the parallel config
    parallel_config.placement_group = current_placement_group


def _verify_placement_group_gpus(placement_group, world_size: int) -> None:
    """Check that an existing placement group can host ``world_size`` workers.

    Raises:
        ValueError: If any bundle has more than one GPU, or if fewer than
            ``world_size`` bundles carry a GPU.
    """
    gpu_bundles = 0
    for bundle in placement_group.bundle_specs:
        bundle_gpus = bundle.get("GPU", 0)
        if bundle_gpus > 1:
            raise ValueError(
                "Placement group bundle cannot have more than 1 GPU.")
        if bundle_gpus:
            gpu_bundles += 1
    if world_size > gpu_bundles:
        raise ValueError(
            "The number of required GPUs exceeds the total number of "
            "available GPUs in the placement group.")


def _create_gpu_placement_group(world_size: int, timeout: float = 1800):
    """Create a placement group of ``world_size`` one-GPU bundles and wait
    until it is ready.

    Raises:
        ValueError: If the cluster has fewer GPUs than ``world_size``.
    """
    num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
    if world_size > num_gpus_in_cluster:
        raise ValueError(
            "The number of required GPUs exceeds the total number of "
            "available GPUs in the cluster.")
    # One independent dict per bundle (no shared aliasing).
    placement_group_specs = [{"GPU": 1} for _ in range(world_size)]
    placement_group = ray.util.placement_group(placement_group_specs)
    # Wait until PG is ready - this will block until all
    # requested resources are available, and will timeout
    # if they cannot be provisioned.
    try:
        ray.get(placement_group.ready(), timeout=timeout)
    except BaseException:
        # Remove the placement group if it never became ready (timeout,
        # interrupt, ...). Otherwise it outlives this call, because by
        # default Ray only reclaims it when the job exits.
        ray.util.remove_placement_group(placement_group)
        raise
    return placement_group