cpu_executor.py 13.3 KB
Newer Older
1
2
3
import os
from functools import partial
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
4
5
6

import torch

7
import vllm.envs as envs
8
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
9
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
10
11
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                  ResultHandler, WorkerMonitor)
12
13
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
14
from vllm.prompt_adapter.request import PromptAdapterRequest
15
from vllm.sequence import ExecuteModelRequest, SamplerOutput
16
17
18
from vllm.utils import (get_distributed_init_method, get_open_port,
                        get_vllm_instance_id, make_async)
from vllm.worker.worker_base import WorkerWrapperBase
19
20
21
22
23
24

logger = init_logger(__name__)


class CPUExecutor(ExecutorBase):
    """Executor that runs model workers on the CPU backend.

    In synchronous mode the driver worker (rank 0) is created in the current
    process and ranks ``1..tensor_parallel_size-1`` are wrapped in
    ``ProcessWorkerWrapper``. In asynchronous mode (``CPUExecutorAsync``)
    every rank, including the driver, is wrapped in ``ProcessWorkerWrapper``.
    """

    # This executor does not rely on Ray.
    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Prepare the environment, create the workers and load the model.

        Sets the process environment variables the CPU backend relies on,
        normalises the model/cache/scheduler configs, creates the driver and
        remote workers, starts the worker monitor when any worker is wrapped
        in ``ProcessWorkerWrapper``, and finally runs ``init_device`` and
        ``load_model`` on all workers.
        """
        assert self.device_config.device_type == "cpu"
        assert self.lora_config is None, "cpu backend doesn't support LoRA"

        #
        # Environment variables for CPU executor
        #

        # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
        os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Intel OpenMP setting
        ld_preload_str = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_preload_str:
            # The time(milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ['KMP_BLOCKTIME'] = "1"
            # Prevents the CPU to run into low performance state
            os.environ['KMP_TPAUSE'] = "0"
            # Provides fine granularity parallelism
            os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"

        # To hint IPEX uses shared memory based AllReduce
        os.environ["LOCAL_WORLD_SIZE"] = str(
            self.parallel_config.tensor_parallel_size)

        # Adjust settings the CPU backend cannot honour (see the
        # module-level _verify_and_get_* helpers).
        self.model_config = _verify_and_get_model_config(self.model_config)
        self.cache_config = _verify_and_get_cache_config(self.cache_config)
        self.scheduler_config = _verify_and_get_scheduler_config(
            self.scheduler_config)

        # Multiprocessing-based executor does not support multi-node setting.
        # Since it only works for single node, we can use the loopback address
        # 127.0.0.1 for communication.
        ip = "127.0.0.1"
        port = get_open_port()
        self.distributed_init_method = get_distributed_init_method(ip, port)

        is_async = isinstance(self, CPUExecutorAsync)

        world_size = self.parallel_config.tensor_parallel_size
        result_handler = ResultHandler()
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        self.workers = []
        # Always define the attribute so check_health()/shutdown() can test it
        # even when no worker is wrapped in ProcessWorkerWrapper (world_size
        # == 1, sync mode).
        self.worker_monitor: Optional[WorkerMonitor] = None

        if is_async:
            # Async mode: every rank, including the driver, is wrapped in a
            # ProcessWorkerWrapper.
            self.workers = [
                ProcessWorkerWrapper(
                    result_handler,
                    partial(
                        self._create_worker,
                        rank=rank,
                        local_rank=rank,
                    )) for rank in range(0, world_size)
            ]
            self.driver_worker = self.workers[0]
            self.workers = self.workers[1:]
            self.driver_method_invoker = _async_driver_method_invoker
        else:
            # Sync mode: the driver (rank 0) is created directly in this
            # process; only the remaining ranks are wrapped.
            self.driver_worker = self._create_worker()
            self.driver_method_invoker = _driver_method_invoker

            if world_size != 1:
                self.workers = [
                    ProcessWorkerWrapper(
                        result_handler,
                        partial(
                            self._create_worker,
                            rank=rank,
                            local_rank=rank,
                        )) for rank in range(1, world_size)
                ]

        if world_size != 1 or is_async:
            # The monitor covers every ProcessWorkerWrapper instance; in async
            # mode that includes the driver.
            if is_async:
                async_worker_list = self.workers + [self.driver_worker]
            else:
                async_worker_list = self.workers
            self.worker_monitor = WorkerMonitor(async_worker_list,
                                                result_handler)
            result_handler.start()
            self.worker_monitor.start()

        self._run_workers("init_device")
        self._run_workers("load_model")

    def _create_worker(
        self,
        local_rank: int = 0,
        rank: int = 0,
    ):
        """Create and initialise one ``CPUWorker`` for the given rank.

        Args:
            local_rank: Local rank forwarded to the worker.
            rank: Global rank forwarded to the worker; rank 0 is flagged as
                the driver worker.

        Returns:
            The ``worker`` attribute of the initialised ``WorkerWrapperBase``.
        """
        worker_module_name = "vllm.worker.cpu_worker"
        worker_class_name = "CPUWorker"

        wrapper = WorkerWrapperBase(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
        )

        assert self.distributed_init_method is not None

        kwargs = dict(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=self.distributed_init_method,
            lora_config=self.lora_config,
            multimodal_config=self.multimodal_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            prompt_adapter_config=self.prompt_adapter_config,
            is_driver_worker=rank == 0,
        )
        wrapper.init_worker(**kwargs)

        return wrapper.worker

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            method: Name of the worker method to invoke.
            *args: Positional arguments forwarded to the method.
            async_run_remote_workers_only: If True the method will be run only
                in the remote workers, not the driver worker. It will also be
                run asynchronously and return a list of futures rather than
                blocking on the results.
            max_concurrent_workers: Not supported; any truthy value raises.
            **kwargs: Keyword arguments forwarded to the method.

        Returns:
            A list with the driver worker's output first, followed by the
            result of each remote worker; or, when
            ``async_run_remote_workers_only`` is set, the list of objects
            returned by the remote workers' ``execute_method``.

        Raises:
            NotImplementedError: If ``max_concurrent_workers`` is truthy.
        """

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        # Start the workers first.
        worker_outputs = [
            worker.execute_method(method, *args, **kwargs)
            for worker in self.workers
        ]

        if async_run_remote_workers_only:
            # Just return futures
            return worker_outputs

        driver_worker_output = self.driver_method_invoker(
            self.driver_worker, method, *args, **kwargs)

        # Get the results of the workers.
        return [driver_worker_output
                ] + [output.get() for output in worker_outputs]

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
        """
        return self.driver_method_invoker(self.driver_worker,
                                          "determine_num_available_blocks")

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker.
        """
        # NOTE: We log here to avoid multiple logs when number of workers is
        # greater than one. We could log in the engine, but not all executors
        # have GPUs.
        # NOTE: `cpu block` for CPU backend is located on CPU memory but is
        # referred as `gpu block`. Because we want to reuse the existing block
        # management procedure.
        logger.info("# CPU blocks: %d", num_gpu_blocks)

        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Run ``execute_model`` on the driver worker and return its output.

        When tensor parallelism is enabled and the remote execution loop has
        not been started yet, ``start_worker_execution_loop`` is first
        launched on the remote workers without waiting for it to finish.
        """
        if (self.parallel_config.tensor_parallel_size > 1
                and self.parallel_worker_tasks is None):
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
            )
        output = self.driver_method_invoker(self.driver_worker,
                                            "execute_model", execute_model_req)
        return output

    def stop_remote_worker_execution_loop(self) -> None:
        """Stop the model execution loop running in the remote workers.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers. Does nothing when no
        loop has been started.
        """
        if self.parallel_worker_tasks is None:
            return
        self.driver_method_invoker(self.driver_worker, "execute_model", None)
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Run ``add_lora`` on all workers; True if every worker returned a
        truthy result."""
        return all(self._run_workers("add_lora", lora_request))

    def remove_lora(self, lora_id: int) -> bool:
        """Run ``remove_lora`` on all workers; True if every worker returned
        a truthy result."""
        return all(self._run_workers("remove_lora", lora_id))

    def pin_lora(self, lora_id: int) -> bool:
        """Run ``pin_lora`` on all workers; True if every worker returned a
        truthy result."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self._run_workers(
            "pin_lora",
            lora_id=lora_id,
        ))

    def list_loras(self) -> Set[int]:
        """Return the result of ``list_loras`` on the driver worker."""
        return self.driver_method_invoker(self.driver_worker, "list_loras")

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Run ``add_prompt_adapter`` on all workers; True if every worker
        returned a truthy result."""
        return all(
            self._run_workers(
                "add_prompt_adapter",
                prompt_adapter_request,
            ))

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Run ``remove_prompt_adapter`` on all workers; True if every worker
        returned a truthy result."""
        return all(
            self._run_workers(
                "remove_prompt_adapter",
                prompt_adapter_id,
            ))

    def list_prompt_adapters(self) -> Set[int]:
        """Return the result of ``list_prompt_adapters`` on the driver
        worker."""
        return self.driver_method_invoker(self.driver_worker,
                                          "list_prompt_adapters")

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Run ``pin_prompt_adapter`` on all workers; True if every worker
        returned a truthy result."""
        return all(self._run_workers(
            "pin_prompt_adapter",
            prompt_adapter_id,
        ))

    def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        # Use getattr (as shutdown() does): the monitor attribute may be
        # missing if initialisation did not complete, and is None when no
        # worker is wrapped in ProcessWorkerWrapper.
        worker_monitor = getattr(self, "worker_monitor", None)
        if worker_monitor is not None and not worker_monitor.is_alive():
            raise RuntimeError("Worker processes are not running")

    def shutdown(self):
        """Close the worker monitor, if one was created."""
        if (worker_monitor := getattr(self, "worker_monitor",
                                      None)) is not None:
            worker_monitor.close()

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        for result in parallel_worker_tasks:
            result.get()
298
299


300
301
302
class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
    """Asynchronous flavour of ``CPUExecutor``.

    Exposes coroutine entry points that delegate to the synchronous
    implementations inherited from ``CPUExecutor``.
    """

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Await ``execute_model`` wrapped with ``make_async``.

        Args:
            execute_model_req: Request forwarded unchanged to
                ``execute_model``.

        Returns:
            Whatever ``execute_model`` returns for the request.
        """
        output = await make_async(self.execute_model
                                  )(execute_model_req=execute_model_req, )
        return output

    async def check_health_async(self) -> None:
        """Run the synchronous ``check_health`` check; raises if unhealthy."""
        self.check_health()
311
312


313
314
315
316
317
318
319
320
321
322
323
324
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    """Adjust a model config in place for the CPU backend and return it.

    A ``float16`` dtype is replaced by ``bfloat16`` and eager mode is forced
    on; a warning is logged for each adjustment that is actually made.
    """
    uses_half_precision = config.dtype == torch.float16
    if uses_half_precision:
        logger.warning("float16 is not supported on CPU, casting to bfloat16.")
        config.dtype = torch.bfloat16

    eager_already_enforced = bool(config.enforce_eager)
    if not eager_already_enforced:
        logger.warning(
            "CUDA graph is not supported on CPU, fallback to the eager "
            "mode.")
        config.enforce_eager = True

    return config


325
326
327
328
329
330
331
332
333
def _verify_and_get_scheduler_config(
        config: SchedulerConfig) -> SchedulerConfig:
    """Turn off chunked prefill on the given scheduler config, if enabled.

    The config is modified in place and returned; a warning is logged only
    when the flag had to be cleared.
    """
    if not config.chunked_prefill_enabled:
        # Nothing to adjust.
        return config

    logger.warning("Chunked prefill is not supported on CPU, disable it.")
    config.chunked_prefill_enabled = False
    return config


334
335
336
337
338
339
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
    """Adjust a cache config in place for the CPU backend and return it.

    Prefix caching is switched off (with a warning) when enabled, and
    ``config.cpu_kvcache_space_bytes`` is derived from
    ``envs.VLLM_CPU_KVCACHE_SPACE`` (interpreted as GiB). A value of 0 is
    treated as "not set" and falls back to 4 GiB.

    Raises:
        RuntimeError: If ``VLLM_CPU_KVCACHE_SPACE`` is negative.
    """
    _GB = 1 << 30
    if config.enable_prefix_caching:
        logger.warning("Prefix caching is not supported on CPU, disable it.")
        config.enable_prefix_caching = False

    kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE

    # Reject negative values up front so the happy path stays flat.
    if kv_cache_space < 0:
        raise RuntimeError(
            "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
            f" {kv_cache_space}, expect a positive integer value.")

    if kv_cache_space == 0:
        config.cpu_kvcache_space_bytes = 4 * _GB  # type: ignore
        logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
                       "for CPU backend is not set, using 4 by default.")
    else:
        config.cpu_kvcache_space_bytes = kv_cache_space * _GB  # type: ignore

    return config
355
356
357
358
359
360
361
362


def _driver_method_invoker(driver, method: str, *args, **kwargs):
    return getattr(driver, method)(*args, **kwargs)


def _async_driver_method_invoker(driver, method: str, *args, **kwargs):
    return driver.execute_method(method, *args, **kwargs).get()