mp_distributed_executor.py 8.16 KB
Newer Older
1
import asyncio
2
3
4
from typing import Any, Callable, List, Optional, Union

import cloudpickle
5

6
from vllm.executor.executor_base import DistributedExecutorBase
7
8
9
from vllm.executor.multiproc_worker_utils import (
    ProcessWorkerWrapper, ResultHandler, WorkerMonitor,
    set_multiprocessing_worker_envs)
10
from vllm.logger import init_logger
11
12
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
13
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
14
                        get_ip, get_open_port, make_async, run_method)
15
from vllm.worker.worker_base import WorkerWrapperBase
16
17
18
19

# Module-level logger, created through vLLM's init_logger for this module.
logger = init_logger(__name__)


20
21
class MultiprocessingDistributedExecutor(DistributedExecutorBase):
    """Python multiprocessing-based distributed executor"""
22

23
24
    uses_ray: bool = False

25
26
    def _init_executor(self) -> None:
        """Create the workers, initialize them and load the model.

        Rank 0 runs in this process as ``self.driver_worker``; ranks
        ``1..world_size-1`` are wrapped in ``ProcessWorkerWrapper`` objects
        and watched by a ``WorkerMonitor``. Afterwards ``init_worker``,
        ``init_device`` and ``load_model`` are run on every worker through
        ``_run_workers``.
        """
        world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size

        # Set multiprocessing envs that are common to V0 and V1
        set_multiprocessing_worker_envs(self.parallel_config)

        self.workers: List[ProcessWorkerWrapper] = []
        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[ProcessWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: List[ProcessWorkerWrapper] = []

        if world_size == 1:
            # Only the in-process driver exists, so there is nothing to
            # monitor.
            self.worker_monitor = None
        else:
            result_handler = ResultHandler()
            for rank in range(1, world_size):
                worker = ProcessWorkerWrapper(result_handler,
                                              WorkerWrapperBase,
                                              self.vllm_config, rank)
                self.workers.append(worker)
                if rank % tensor_parallel_size == 0:
                    self.tp_driver_workers.append(worker)
                else:
                    self.non_driver_workers.append(worker)

            self.worker_monitor = WorkerMonitor(self.workers, result_handler)
            result_handler.start()
            self.worker_monitor.start()

        self.driver_worker = WorkerWrapperBase(self.vllm_config, 0)

        # The init method is computed exactly once, here. A loopback
        # ("127.0.0.1") value used to be built at the top of this method but
        # was always overwritten before any use, so the get_ip()-based value
        # below is the one that has been in effect.
        # NOTE(review): this executor is documented as single-node only, so
        # loopback may have been the intended address -- confirm before
        # switching.
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())

        # One kwargs dict per rank; local_rank equals rank here.
        all_kwargs = [
            dict(
                vllm_config=self.vllm_config,
                local_rank=rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % tensor_parallel_size == 0),
            ) for rank in range(world_size)
        ]
        self._run_workers("init_worker", all_kwargs)
        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        self.driver_exec_model = make_async(self.driver_worker.execute_model)
        # Created lazily in _driver_execute_model_async.
        self.pp_locks: Optional[List[asyncio.Lock]] = None
94

95
96
97
98
99
    def shutdown(self):
        """Close the worker monitor if this executor has one.

        The attribute is looked up defensively, so calling this on an
        instance where ``worker_monitor`` was never set (or is ``None``)
        is a no-op.
        """
        monitor = getattr(self, "worker_monitor", None)
        if monitor is None:
            return
        monitor.close()

100
    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Execute the model on the in-process driver worker.

        Args:
            execute_model_req: The request to execute. Passing None will
                cause the driver to stop the model execution loop running
                in each of the remote workers.

        Returns:
            Whatever the driver worker's ``execute_model`` returns.
        """
        driver = self.driver_worker
        return driver.execute_model(execute_model_req)
109

110
111
    def _run_workers(
        self,
        method: Union[str, Callable],
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            method: Name of the method to run, or a callable. A callable is
                serialized with cloudpickle before being sent.
            *args: Positional arguments forwarded to the method.
            async_run_tensor_parallel_workers_only: If True the method is
                dispatched only to ``self.non_driver_workers`` (not the
                driver worker) and the pending results are returned
                without waiting on them.
            max_concurrent_workers: Not supported; any truthy value raises
                NotImplementedError.
            **kwargs: Keyword arguments forwarded to the method.

        Returns:
            In the async mode, the list of pending results. Otherwise a
            list holding the driver worker's output first, followed by the
            output of each worker in ``self.workers``, in order.

        Raises:
            NotImplementedError: If ``max_concurrent_workers`` is truthy.
        """
        # Method names travel as-is; callables travel in pickled form.
        sent_method = (method if isinstance(method, str) else
                       cloudpickle.dumps(method))
        del method

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        if async_run_tensor_parallel_workers_only:
            # Dispatch to the non-driver workers only and hand back the
            # pending results untouched.
            pending = []
            for remote in self.non_driver_workers:
                pending.append(
                    remote.execute_method(sent_method, *args, **kwargs))
            return pending

        # Kick off every remote worker before doing the local work.
        remote_results = []
        for remote in self.workers:
            remote_results.append(
                remote.execute_method(sent_method, *args, **kwargs))

        local_output = run_method(self.driver_worker, sent_method, args,
                                  kwargs)

        # Driver output first, then each remote result in worker order.
        outputs = [local_output]
        outputs.extend(result.get() for result in remote_results)
        return outputs

    def check_health(self) -> None:
        """Raise if the worker monitor exists and is no longer alive.

        Raises:
            RuntimeError: If ``self.worker_monitor`` is set and its
                ``is_alive()`` returns a falsy value.
        """
        monitor = self.worker_monitor
        if monitor is None:
            # No monitor means there is nothing to check.
            return
        if not monitor.is_alive():
            raise RuntimeError("Worker processes are not running")

162
163
164
165
166
167
168
169
170
171
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Call ``get()`` on each pending result, in order.

        Args:
            parallel_worker_tasks: Iterable of pending results, such as the
                list returned by ``_run_workers()`` when called with
                ``async_run_tensor_parallel_workers_only=True``.
        """
        for pending in parallel_worker_tasks:
            pending.get()

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run ``execute_model`` on the driver of every pipeline stage.

        With no extra TP driver workers, the request is simply awaited on
        the in-process driver. Otherwise one task is created for the
        in-process driver plus one per entry of ``self.tp_driver_workers``,
        each guarded by its own lock, and the output of the last task is
        returned.

        Args:
            execute_model_req: The request forwarded to every stage.

        Returns:
            The result of the last task, i.e. the final pipeline stage.
        """
        if not self.tp_driver_workers:
            return await self.driver_exec_model(execute_model_req)

        if self.pp_locks is None:
            # This locks each pipeline parallel stage so multiple virtual
            # engines can't execute on the same stage at the same time.
            # The locks are created here rather than in the constructor,
            # which uses a different asyncio loop.
            stage_count = self.parallel_config.pipeline_parallel_size
            self.pp_locks = [asyncio.Lock() for _ in range(stage_count)]
        stage_locks = self.pp_locks

        # Stage 0 is the in-process driver.
        first_stage = _run_task_with_lock(self.driver_exec_model,
                                          stage_locks[0], execute_model_req)
        stage_tasks = [asyncio.create_task(first_stage)]

        # Remaining stages are driven by the remote TP driver workers.
        stage_tasks.extend(
            asyncio.create_task(
                _run_task_with_lock(stage_driver.execute_method_async,
                                    stage_locks[stage], "execute_model",
                                    execute_model_req))
            for stage, stage_driver in enumerate(self.tp_driver_workers,
                                                 start=1))

        stage_outputs = await asyncio.gather(*stage_tasks)

        # Only the last PP stage has the final results.
        return stage_outputs[-1]
201

202
203
204
    async def _start_worker_execution_loop(self):
        """Start the execution loop on every non-driver worker.

        Returns:
            The gathered results of ``start_worker_execution_loop`` for
            each worker in ``self.non_driver_workers``, in order.
        """
        loop_calls = []
        for remote in self.non_driver_workers:
            loop_calls.append(
                remote.execute_method_async("start_worker_execution_loop"))
        return await asyncio.gather(*loop_calls)