"vllm/utils/argparse_utils.py" did not exist on "4f8f47e87e4f65195cf77b0de93feb63fc5a5b2f"
mp_distributed_executor.py 7.94 KB
Newer Older
1
import asyncio
2
from typing import Any, List, Optional
3

4
from vllm.executor.executor_base import DistributedExecutorBase
5
6
7
from vllm.executor.multiproc_worker_utils import (
    ProcessWorkerWrapper, ResultHandler, WorkerMonitor,
    set_multiprocessing_worker_envs)
8
from vllm.logger import init_logger
9
10
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
11
12
13
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
                        get_ip, get_open_port, make_async)
from vllm.worker.worker_base import WorkerWrapperBase
14
15
16
17

# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)


18
19
class MultiprocessingDistributedExecutor(DistributedExecutorBase):
    """Python multiprocessing-based distributed executor.

    Rank 0 is held as ``self.driver_worker`` and its methods are invoked
    directly from this object. Ranks ``1..world_size - 1`` are each wrapped
    in a ``ProcessWorkerWrapper`` and driven through ``execute_method`` /
    ``execute_method_async``.
    """

    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Create all workers, initialize them and load the model.

        Populates ``self.workers``, ``self.tp_driver_workers``,
        ``self.non_driver_workers``, ``self.worker_monitor``,
        ``self.driver_worker``, ``self.driver_exec_model`` and
        ``self.pp_locks``.
        """
        # Create the parallel GPU workers.
        world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size

        # Set multiprocessing envs that are common to V0 and V1
        set_multiprocessing_worker_envs(self.parallel_config)

        self.workers: List[ProcessWorkerWrapper] = []
        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[ProcessWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: List[ProcessWorkerWrapper] = []

        if world_size == 1:
            # Only the in-process driver exists; nothing to monitor.
            self.worker_monitor = None
        else:
            result_handler = ResultHandler()
            for rank in range(1, world_size):
                worker = ProcessWorkerWrapper(result_handler,
                                              WorkerWrapperBase,
                                              self.vllm_config, rank)
                self.workers.append(worker)
                if rank % tensor_parallel_size == 0:
                    self.tp_driver_workers.append(worker)
                else:
                    self.non_driver_workers.append(worker)

            self.worker_monitor = WorkerMonitor(self.workers, result_handler)
            result_handler.start()
            self.worker_monitor.start()

        self.driver_worker = WorkerWrapperBase(self.vllm_config, 0)

        # Multiprocessing-based executor does not support multi-node setting.
        # Since it only works for single node, we can use the loopback address
        # 127.0.0.1 for communication.
        # The init method is built exactly once, right before it is handed to
        # the workers, so only a single open port is requested.
        distributed_init_method = get_distributed_init_method(
            "127.0.0.1", get_open_port())

        # One kwargs dict per rank. local_rank equals rank because all
        # workers live on the same node.
        all_kwargs = [
            dict(
                vllm_config=self.vllm_config,
                local_rank=rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                # The first rank of every TP group acts as a driver.
                is_driver_worker=(rank % tensor_parallel_size == 0),
            ) for rank in range(world_size)
        ]
        self._run_workers("init_worker", all_kwargs)
        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

        self.driver_exec_model = make_async(self.driver_worker.execute_model)
        # Created lazily in _driver_execute_model_async (see comment there).
        self.pp_locks: Optional[List[asyncio.Lock]] = None

    def shutdown(self) -> None:
        """Close the worker monitor, if one was created.

        Uses ``getattr`` so that calling this on a partially initialized
        executor (no ``worker_monitor`` attribute yet) is a no-op.
        """
        if (worker_monitor := getattr(self, "worker_monitor",
                                      None)) is not None:
            worker_monitor.close()

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_model(execute_model_req)

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            method: Name of the worker method to invoke.
            *args: Positional arguments forwarded to the method.
            async_run_tensor_parallel_workers_only: If True the method will be
                run only in the remote TP workers, not the driver worker.
                It will also be run asynchronously and return a list of futures
                rather than blocking on the results.
            max_concurrent_workers: Not supported; any truthy value raises.
            **kwargs: Keyword arguments forwarded to the method.

        Returns:
            In async mode, the list of pending results from
            ``self.non_driver_workers``. Otherwise a list whose first element
            is the driver worker's output, followed by the outputs of
            ``self.workers`` in order.

        Raises:
            NotImplementedError: If ``max_concurrent_workers`` is truthy.
        """
        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        if async_run_tensor_parallel_workers_only:
            # Run only non-driver workers and just return futures.
            return [
                worker.execute_method(method, *args, **kwargs)
                for worker in self.non_driver_workers
            ]

        # Start all remote workers first.
        worker_outputs = [
            worker.execute_method(method, *args, **kwargs)
            for worker in self.workers
        ]

        # Then run the driver while the remote calls are in flight.
        driver_worker_method = getattr(self.driver_worker, method)
        driver_worker_output = driver_worker_method(*args, **kwargs)

        # Get the results of the workers.
        return [driver_worker_output
                ] + [output.get() for output in worker_outputs]

    def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        if (self.worker_monitor is not None
                and not self.worker_monitor.is_alive()):
            raise RuntimeError("Worker processes are not running")

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only to complete."""
        for result in parallel_worker_tasks:
            result.get()

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model on the driver of every TP group.

        When there are no additional TP-group drivers, the request is run on
        the in-process driver only. Otherwise one task is started per
        driver, each guarded by its own lock from ``self.pp_locks``, and the
        result of the last task is returned.
        """
        if not self.tp_driver_workers:
            return await self.driver_exec_model(execute_model_req)

        if self.pp_locks is None:
            # This locks each pipeline parallel stage so multiple virtual
            # engines can't execute on the same stage at the same time
            # We create the locks here to avoid creating them in the constructor
            # which uses a different asyncio loop.
            self.pp_locks = [
                asyncio.Lock()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]

        # Lock 0 guards the in-process driver.
        tasks = [
            asyncio.create_task(
                _run_task_with_lock(self.driver_exec_model, self.pp_locks[0],
                                    execute_model_req))
        ]
        # Locks 1.. guard the remote TP-group drivers, in list order.
        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
                                                start=1):
            tasks.append(
                asyncio.create_task(
                    _run_task_with_lock(driver_worker.execute_method_async,
                                        self.pp_locks[pp_rank],
                                        "execute_model", execute_model_req)))
        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]

    async def _start_worker_execution_loop(self):
        """Start ``start_worker_execution_loop`` on every non-driver worker
        and wait for all of them to finish."""
        coros = [
            worker.execute_method_async("start_worker_execution_loop")
            for worker in self.non_driver_workers
        ]
        return await asyncio.gather(*coros)