ray_utils.py 5.34 KB
Newer Older
1
import pickle
2
from typing import List, Optional, Tuple
3
4

from vllm.config import ParallelConfig
5
from vllm.logger import init_logger
6
from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
7
8

# Module-level logger obtained from vLLM's `init_logger` helper (vllm.logger).
logger = init_logger(__name__)
9
10
11

try:
    import ray

    class RayWorkerVllm:
        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES.

        The wrapped worker is built by `init_worker`; until then
        `self.worker` is None. Attribute lookups that fail on the wrapper
        itself are forwarded to the wrapped worker via `__getattr__`.
        """

        def __init__(self, init_cached_hf_modules=False) -> None:
            """Create an empty wrapper (no worker yet).

            Args:
                init_cached_hf_modules: If True, call
                    `transformers.dynamic_module_utils.init_hf_modules()`
                    before anything else.
            """
            if init_cached_hf_modules:
                from transformers.dynamic_module_utils import init_hf_modules
                init_hf_modules()
            self.worker = None
            # The compiled DAG runs its main execution in a different
            # thread that calls cuda.set_device. This flag records whether
            # set_device has already been called for that execution path.
            self.compiled_dag_cuda_device_set = False

        def init_worker(self, worker_init_fn):
            """Build the wrapped worker by calling `worker_init_fn()`."""
            self.worker = worker_init_fn()

        def __getattr__(self, name):
            """Forward unknown attribute lookups to the wrapped worker.

            Python only calls `__getattr__` when normal lookup fails. If
            `worker` itself is missing from the instance (for example, the
            object was created without running `__init__`, as happens
            during unpickling or copying), evaluating `self.worker` here
            would re-enter `__getattr__` forever. Raise AttributeError
            instead so such callers see a normal missing attribute.
            """
            if name == "worker":
                raise AttributeError(name)
            return getattr(self.worker, name)

        def execute_method(self, method, *args, **kwargs):
            """Call the attribute named `method` with the given arguments.

            The lookup goes through `getattr(self, method)`, so it also
            reaches methods of the wrapped worker. Any exception is logged
            with its traceback and then re-raised unchanged.
            """
            try:
                executor = getattr(self, method)
                return executor(*args, **kwargs)
            except Exception:
                # exceptions in ray worker may cause deadlock
                # see https://github.com/vllm-project/vllm/issues/3455
                # print the error and inform the user to solve the error
                msg = (f"Error executing method {method}. "
                       "This might cause deadlock in distributed execution.")
                logger.exception(msg)
                # Bare `raise` re-raises the active exception with its
                # original traceback.
                raise

        def get_node_ip(self) -> str:
            """Return the IP address reported by `vllm.utils.get_ip`."""
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            """Return this actor's Ray node id and the GPU ids Ray assigned."""
            node_id = ray.get_runtime_context().get_node_id()
            gpu_ids = ray.get_gpu_ids()
            return node_id, gpu_ids

        def set_cuda_visible_devices(self, device_ids) -> None:
            """Delegate to the module-level `set_cuda_visible_devices`."""
            set_cuda_visible_devices(device_ids)

        def execute_model_compiled_dag_remote(self, ignored):
            """Used only when compiled DAG is enabled.

            On the first call, sets the CUDA device to `self.worker.device`.
            Then runs `self.worker.execute_model()` and returns its output
            serialized with `pickle.dumps` (bytes).

            Args:
                ignored: Unused input.
            """
            import torch
            if not self.compiled_dag_cuda_device_set:
                torch.cuda.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

            output = self.worker.execute_model()
            output = pickle.dumps(output)
            return output

except ImportError as e:
    logger.warning(f"Failed to import Ray with {e!r}. "
                   "For distributed inference, please install Ray with "
                   "`pip install ray`.")
    ray = None
    RayWorkerVllm = None


def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
) -> None:
    """Initialize the distributed cluster with Ray.

    It will connect to the Ray cluster and create a placement group
    for the workers, which includes the specification of the resources
    for each distributed worker. The chosen placement group is stored in
    `parallel_config.placement_group`. If that attribute is already set,
    the function returns right after connecting.

    Args:
        parallel_config: The configurations for parallel execution.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Raises:
        ImportError: If Ray is not installed.
        ValueError: If a bundle of the current placement group requests
            more than one GPU, or if `parallel_config.world_size` exceeds
            the number of GPUs available in the current placement group
            or in the cluster.
    """
    if ray is None:
        raise ImportError(
            "Ray is not installed. Please install Ray to use distributed "
            "serving.")

    # Connect to a ray cluster.
    if is_hip():
        # NOTE(review): Ray may reject `num_gpus` when connecting to an
        # already-running cluster (e.g. a non-None ray_address) — confirm.
        ray.init(address=ray_address,
                 ignore_reinit_error=True,
                 num_gpus=parallel_config.world_size)
    else:
        ray.init(address=ray_address, ignore_reinit_error=True)

    if parallel_config.placement_group:
        # Placement group is already set.
        return

    # Create placement group for worker processes
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group: each bundle may hold
        # at most one GPU, and there must be a GPU bundle per worker.
        gpu_bundles = 0
        for bundle in bundles:
            bundle_gpus = bundle.get("GPU", 0)
            if bundle_gpus > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 GPU.")
            if bundle_gpus:
                gpu_bundles += 1
        if parallel_config.world_size > gpu_bundles:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the placement group.")
    else:
        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
        if parallel_config.world_size > num_gpus_in_cluster:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the cluster.")
        # Create a new placement group with one single-GPU bundle per
        # worker. Build a distinct dict per bundle: `[{...}] * n` would
        # make every bundle alias the same dict object.
        placement_group_specs = [
            {"GPU": 1} for _ in range(parallel_config.world_size)
        ]
        current_placement_group = ray.util.placement_group(
            placement_group_specs)
        # Wait until PG is ready - this will block until all
        # requested resources are available, and will timeout
        # (after 1800 seconds) if they cannot be provisioned.
        ray.get(current_placement_group.ready(), timeout=1800)

    # Set the placement group in the parallel config
    parallel_config.placement_group = current_placement_group