# ray_utils.py
1
import pickle
2
from typing import Callable, List, Optional, Tuple
3
4

from vllm.config import ParallelConfig
5
from vllm.logger import init_logger
6
from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
7
from vllm.worker.worker import Worker
8
9

# Module-level logger; used for the Ray import-failure warning and for
# logging exceptions raised by methods executed on Ray workers.
logger = init_logger(__name__)
10
11
12

try:
    import ray

    class RayWorkerVllm:
        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES.

        Attribute lookups that the wrapper itself does not define are
        forwarded to the wrapped worker once ``init_worker`` has run.
        """

        def __init__(self, init_cached_hf_modules=False) -> None:
            if init_cached_hf_modules:
                # Imported lazily: only needed when HF dynamic modules must
                # be registered inside this actor process.
                from transformers.dynamic_module_utils import init_hf_modules
                init_hf_modules()
            self._worker: Optional[Worker] = None
            # Since the compiled DAG runs a main execution
            # in a different thread that calls cuda.set_device.
            # The flag indicates whether set_device has been called on
            # that thread.
            self.compiled_dag_cuda_device_set = False

        def init_worker(self, worker_init_fn: Callable[[], Worker]):
            """Create the wrapped worker by calling ``worker_init_fn()``."""
            self._worker = worker_init_fn()

        @property
        def worker(self) -> Worker:
            """The wrapped worker; ``init_worker`` must have been called."""
            assert self._worker is not None
            return self._worker

        def __getattr__(self, name):
            """Forward attributes not found on the wrapper to the worker.

            Raises:
                AttributeError: If the worker has not been initialized, or
                    if the worker itself has no attribute ``name``.
            """
            # __getattr__ only runs when normal lookup fails. Read _worker
            # straight from __dict__: going through ``self._worker`` or the
            # ``worker`` property would re-enter __getattr__ (infinite
            # recursion) on an instance created without __init__, e.g. while
            # unpickling or copying.
            worker = self.__dict__.get("_worker")
            if worker is None:
                # AttributeError (not AssertionError) keeps hasattr() and
                # getattr(obj, name, default) working as expected.
                raise AttributeError(
                    f"{type(self).__name__!r} object has no attribute "
                    f"{name!r} (the wrapped worker is not initialized; "
                    "call init_worker first)")
            return getattr(worker, name)

        def execute_method(self, method, *args, **kwargs):
            """Call ``getattr(self, method)(*args, **kwargs)`` and return it.

            Any exception is logged with its traceback and then re-raised
            unchanged.
            """
            try:
                executor = getattr(self, method)
                return executor(*args, **kwargs)
            except Exception:
                # exceptions in ray worker may cause deadlock
                # see https://github.com/vllm-project/vllm/issues/3455
                # print the error and inform the user to solve the error
                msg = (f"Error executing method {method}. "
                       "This might cause deadlock in distributed execution.")
                logger.exception(msg)
                # Bare raise re-raises the active exception with its
                # original traceback.
                raise

        def get_node_ip(self) -> str:
            """Return the IP address reported by ``vllm.utils.get_ip``."""
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            """Return this actor's Ray node id and Ray-assigned GPU ids."""
            node_id = ray.get_runtime_context().get_node_id()
            gpu_ids = ray.get_gpu_ids()
            return node_id, gpu_ids

        def set_cuda_visible_devices(self, device_ids) -> None:
            """Forward to ``vllm.utils.set_cuda_visible_devices``."""
            # Inside the method body this name resolves to the module-level
            # function imported from vllm.utils, not to this method.
            set_cuda_visible_devices(device_ids)

        def execute_model_compiled_dag_remote(self, ignored):
            """Used only when compiled DAG is enabled.

            The ``ignored`` argument is not used. Returns the pickled result
            of ``self.worker.execute_model()``.
            """
            import torch
            # Bind the worker's CUDA device once, on the first call made
            # through the compiled DAG (see the flag set in __init__).
            if not self.compiled_dag_cuda_device_set:
                torch.cuda.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

            output = self.worker.execute_model()
            output = pickle.dumps(output)
            return output

except ImportError as e:
    # Ray is optional: record the failure and expose None placeholders so
    # initialize_ray_cluster can raise a clear ImportError later.
    logger.warning(f"Failed to import Ray with {e!r}. "
                   "For distributed inference, please install Ray with "
                   "`pip install ray`.")
    ray = None  # type: ignore
    RayWorkerVllm = None  # type: ignore


83
def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
):
    """Initialize the distributed cluster with Ray.

    It will connect to the Ray cluster and create a placement group
    for the workers, which includes the specification of the resources
    for each distributed worker. If ``parallel_config.placement_group`` is
    already set, only the connection step is performed.

    If this process already runs inside a Ray placement group, that group
    is validated and reused. Otherwise a new group with one single-GPU
    bundle per worker is created. The call blocks until that group is ready
    or the wait times out.

    Args:
        parallel_config: The configurations for parallel execution. Its
            ``placement_group`` attribute is set by this function unless it
            was already set.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Raises:
        ImportError: If Ray is not installed.
        ValueError: If a bundle of the current placement group requests
            more than one GPU, or if ``parallel_config.world_size`` exceeds
            the GPUs available in the placement group or the cluster.
    """
    if ray is None:
        raise ImportError(
            "Ray is not installed. Please install Ray to use distributed "
            "serving.")

    # Connect to a ray cluster.
    if is_hip():
        # NOTE(review): GPUs are declared explicitly on ROCm, presumably
        # because Ray does not auto-detect them there - confirm.
        ray.init(address=ray_address,
                 ignore_reinit_error=True,
                 num_gpus=parallel_config.world_size)
    else:
        ray.init(address=ray_address, ignore_reinit_error=True)

    if parallel_config.placement_group:
        # Placement group is already set.
        return

    # Create placement group for worker processes
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group: every bundle may hold
        # at most one GPU, and there must be one GPU bundle per worker.
        gpu_bundles = 0
        for bundle in bundles:
            bundle_gpus = bundle.get("GPU", 0)
            if bundle_gpus > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 GPU.")
            if bundle_gpus:
                gpu_bundles += 1
        if parallel_config.world_size > gpu_bundles:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the placement group.")
    else:
        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
        if parallel_config.world_size > num_gpus_in_cluster:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the cluster.")
        # Create a new placement group with one single-GPU bundle per
        # worker. Build a fresh dict for each bundle instead of repeating
        # one dict object (``[{...}] * n``), so bundles are not aliases.
        placement_group_specs = [{"GPU": 1}
                                 for _ in range(parallel_config.world_size)]
        current_placement_group = ray.util.placement_group(
            placement_group_specs)
        # Wait until PG is ready - this will block until all
        # requested resources are available, and will timeout
        # (after 1800 s = 30 min) if they cannot be provisioned.
        ray.get(current_placement_group.ready(), timeout=1800)

    # Set the placement group in the parallel config
    parallel_config.placement_group = current_placement_group