mp_distributed_executor.py 9.62 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import asyncio
youkaichao's avatar
youkaichao committed
4
import os
5
6
7
from typing import Any, Callable, List, Optional, Union

import cloudpickle
8

9
from vllm.executor.executor_base import DistributedExecutorBase
10
11
12
from vllm.executor.multiproc_worker_utils import (
    ProcessWorkerWrapper, ResultHandler, WorkerMonitor,
    set_multiprocessing_worker_envs)
13
from vllm.logger import init_logger
14
15
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
youkaichao's avatar
youkaichao committed
16
17
18
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
                        get_distributed_init_method, get_ip, get_open_port,
                        make_async, run_method, update_environment_variables)
19
from vllm.worker.worker_base import WorkerWrapperBase
20
21
22
23

logger = init_logger(__name__)


24
25
class MultiprocessingDistributedExecutor(DistributedExecutorBase):
    """Python multiprocessing-based distributed executor"""
26

27
28
    uses_ray: bool = False

youkaichao's avatar
youkaichao committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
    def _check_cuda(self) -> None:
        """Check that enough local GPUs exist for the parallel configuration.

        Kept separate from _init_executor to reduce the number of indented
        blocks there. When CUDA_VISIBLE_DEVICES is not already set, it is set
        to "0,1,...,world_size-1" in the driver's environment.

        Raises:
            RuntimeError: If tensor_parallel_size or world_size is greater
                than the count returned by cuda_device_count_stateless().
        """
        parallel_config = self.parallel_config
        world_size = parallel_config.world_size
        tensor_parallel_size = parallel_config.tensor_parallel_size

        cuda_device_count = cuda_device_count_stateless()
        # Use a less confusing message for the more common TP-only case.
        if tensor_parallel_size > cuda_device_count:
            raise RuntimeError(
                f"please set tensor_parallel_size ({tensor_parallel_size}) "
                f"to less than or equal to max local gpu count "
                f"({cuda_device_count})")

        if world_size > cuda_device_count:
            raise RuntimeError(
                f"please ensure that world_size ({world_size}) "
                f"is less than or equal to max local gpu count "
                f"({cuda_device_count})")

        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers.
        # A value already present in the environment is left untouched.
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            visible_devices = ",".join(map(str, range(world_size)))
            update_environment_variables(
                {"CUDA_VISIBLE_DEVICES": visible_devices})

56
    def _init_executor(self) -> None:
        """Create the workers and bring every rank to a loaded-model state.

        Steps, in order:
          1. On CUDA-like platforms, validate the GPU count via _check_cuda().
          2. Set the multiprocessing environment for the workers.
          3. Create one ProcessWorkerWrapper for each rank in
             1..world_size-1 and start the result handler and worker
             monitor (skipped entirely when world_size == 1).
          4. Create the rank-0 driver worker as a WorkerWrapperBase held by
             this executor.
          5. Run "init_worker", "init_device" and "load_model" on all
             workers through _run_workers().
        """
        from vllm.platforms import current_platform
        if current_platform.is_cuda_alike():
            self._check_cuda()

        # Create the parallel GPU workers.
        world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size

        # Set multiprocessing envs that are common to V0 and V1
        set_multiprocessing_worker_envs(self.parallel_config)

        # Multiprocessing-based executor does not support multi-node setting.
        # Since it only works for single node, we can use the loopback address
        # 127.0.0.1 for communication. This is the only place the init method
        # is computed, so every rank receives the same loopback endpoint.
        distributed_init_method = get_distributed_init_method(
            "127.0.0.1", get_open_port())

        self.workers: List[ProcessWorkerWrapper] = []
        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[ProcessWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: List[ProcessWorkerWrapper] = []

        if world_size == 1:
            # No remote workers, hence nothing to monitor.
            self.worker_monitor = None
        else:
            result_handler = ResultHandler()
            for rank in range(1, world_size):
                worker = ProcessWorkerWrapper(result_handler,
                                              WorkerWrapperBase,
                                              self.vllm_config, rank)
                self.workers.append(worker)
                if rank % tensor_parallel_size == 0:
                    self.tp_driver_workers.append(worker)
                else:
                    self.non_driver_workers.append(worker)

            self.worker_monitor = WorkerMonitor(self.workers, result_handler)
            result_handler.start()
            self.worker_monitor.start()

        self.driver_worker = WorkerWrapperBase(self.vllm_config, 0)

        # One kwargs dict per rank; local_rank equals rank because all
        # workers are created on this node.
        all_kwargs = []
        for rank in range(world_size):
            all_kwargs.append(
                dict(
                    vllm_config=self.vllm_config,
                    local_rank=rank,
                    rank=rank,
                    distributed_init_method=distributed_init_method,
                    is_driver_worker=(not self.parallel_config)
                    or (rank % tensor_parallel_size == 0),
                ))
        self._run_workers("init_worker", all_kwargs)
        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        self.driver_exec_model = make_async(self.driver_worker.execute_model)
        # Created lazily in _driver_execute_model_async.
        self.pp_locks: Optional[List[asyncio.Lock]] = None
130

131
132
133
134
135
    def shutdown(self):
        """Close the worker monitor if one exists.

        Safe to call when ``worker_monitor`` was never assigned or is None;
        in both cases nothing happens.
        """
        monitor = getattr(self, "worker_monitor", None)
        if monitor is None:
            return
        monitor.close()

136
    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Execute the request on the driver worker and return its output.

        A request of None makes the driver stop the model execution loop
        that is running in each of the remote workers.
        """
        driver = self.driver_worker
        return driver.execute_model(execute_model_req)
145

146
147
    def _run_workers(
        self,
148
        method: Union[str, Callable],
149
        *args,
150
        async_run_tensor_parallel_workers_only: bool = False,
151
152
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
153
    ) -> List[Any]:
154
155
156
        """Runs the given method on all workers.

        Args:
157
158
159
160
            async_run_tensor_parallel_workers_only: If True the method will be
                run only in the remote TP workers, not the driver worker.
                It will also be run asynchronously and return a list of futures
                rather than blocking on the results.
161
        """
162
163
164
165
166
        if isinstance(method, str):
            sent_method = method
        else:
            sent_method = cloudpickle.dumps(method)
        del method
167
168
169
170
171

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

172
173
174
        if async_run_tensor_parallel_workers_only:
            # Run only non-driver workers and just return futures.
            return [
175
                worker.execute_method(sent_method, *args, **kwargs)
176
177
178
179
                for worker in self.non_driver_workers
            ]

        # Start all remote workers first.
180
        worker_outputs = [
181
            worker.execute_method(sent_method, *args, **kwargs)
182
183
184
            for worker in self.workers
        ]

185
186
        driver_worker_output = run_method(self.driver_worker, sent_method,
                                          args, kwargs)
187
188
189
190
191
192
193

        # Get the results of the workers.
        return [driver_worker_output
                ] + [output.get() for output in worker_outputs]

    def check_health(self) -> None:
        """Raise if a worker monitor exists and is no longer alive.

        Nothing is checked when ``self.worker_monitor`` is None.

        Raises:
            RuntimeError: If the worker monitor reports it is not alive.
        """
        if self.worker_monitor is None:
            return
        if self.worker_monitor.is_alive():
            return
        raise RuntimeError("Worker processes are not running")

198
199
200
201
202
203
204
205
206
207
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        for result in parallel_worker_tasks:
            result.get()

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model on the driver and on every TP driver worker.

        With no TP driver workers the request is simply awaited on the
        driver. Otherwise one task is created for the driver and one for
        each entry of ``self.tp_driver_workers``, each guarded by its own
        lock from ``self.pp_locks``, and the result of the last task is
        returned.
        """
        if not self.tp_driver_workers:
            return await self.driver_exec_model(execute_model_req)

        if self.pp_locks is None:
            # This locks each pipeline parallel stage so multiple virtual
            # engines can't execute on the same stage at the same time
            # We create the locks here to avoid creating them in the constructor
            # which uses a different asyncio loop.
            num_stages = self.parallel_config.pipeline_parallel_size
            self.pp_locks = [asyncio.Lock() for _ in range(num_stages)]

        stage_locks = self.pp_locks

        # Lock 0 guards the driver; it is scheduled before any remote stage.
        pending = [
            asyncio.create_task(
                _run_task_with_lock(self.driver_exec_model, stage_locks[0],
                                    execute_model_req))
        ]
        # Remote stages take locks 1, 2, ... in tp_driver_workers order.
        pending.extend(
            asyncio.create_task(
                _run_task_with_lock(stage_worker.execute_method_async,
                                    stage_locks[stage_idx], "execute_model",
                                    execute_model_req))
            for stage_idx, stage_worker in enumerate(self.tp_driver_workers,
                                                     start=1))

        stage_outputs = await asyncio.gather(*pending)

        # Only the last PP stage has the final results.
        return stage_outputs[-1]
237

238
239
240
    async def _start_worker_execution_loop(self):
        coros = [
            worker.execute_method_async("start_worker_execution_loop")
241
            for worker in self.non_driver_workers
242
243
        ]
        return await asyncio.gather(*coros)