cpu.py 11 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
import platform
6
import sys
7
from importlib.util import find_spec
8
from typing import TYPE_CHECKING, Optional
9

10
import psutil
11
12
import torch

13
from vllm.logger import init_logger
14
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
15

16
from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
17
18

# Module-level logger used by the CPU platform hooks below.
logger = init_logger(__name__)
19

20
21
22
23
24
# `VllmConfig` is only imported for static type checking. At runtime the name
# is bound to None, so annotations that mention it still evaluate without
# importing `vllm.config` at module import time (presumably to avoid an
# import cycle — TODO confirm).
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None

25

26
27
28
29
30
31
32
33
34
def get_max_threads(pid=0):
    """Return the number of CPUs usable by process ``pid``.

    Where ``os.sched_getaffinity`` exists (e.g. Linux), this is the size of
    the process's CPU affinity mask. On macOS (``Darwin``), which has no
    affinity API, it falls back to the logical CPU count of the machine and
    ``pid`` is ignored.

    Args:
        pid: Process id to query; ``0`` means the calling process.

    Returns:
        The number of usable CPUs (at least 1 on the Darwin fallback path).

    Raises:
        NotImplementedError: on platforms with neither an affinity API nor
            the Darwin fallback.
    """
    if hasattr(os, 'sched_getaffinity'):
        return len(os.sched_getaffinity(pid))
    if platform.system() == 'Darwin':
        # os.cpu_count() returns None when the count cannot be determined;
        # callers stringify this value into environment variables, so never
        # hand back None.
        return os.cpu_count() or 1
    raise NotImplementedError("Unsupported OS")


35
36
class CpuPlatform(Platform):
    """Platform hooks for running vLLM on the host CPU."""

    _enum = PlatformEnum.CPU
    device_name: str = "cpu"
    device_type: str = "cpu"
    dispatch_key: str = "CPU"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Model dtypes usable on the detected CPU architecture.

        Returns:
            ``[bfloat16, float32]`` on POWERPC, ``[float16, float32]`` on
            ARM macOS, and ``[bfloat16, float16, float32]`` otherwise.
        """
        cpu_arch = self.get_cpu_architecture()
        if cpu_arch == CpuArchEnum.POWERPC:
            return [torch.bfloat16, torch.float32]
        elif sys.platform.startswith("darwin") and cpu_arch == CpuArchEnum.ARM:
            # TODO: change this condition to check if the platform support bf16
            # instead of checking the OS. For instance M2 shall supports bf16
            # already. But we need to modify `cpu_extension.cmake` to activate
            # the feature in the build.
            return [torch.float16, torch.float32]
        # x86/aarch64 CPU has supported both bf16 and fp16 natively.
        return [torch.bfloat16, torch.float16, torch.float32]

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"cpu"``; ``device_id`` is ignored."""
        return "cpu"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Return the dotted import path of the attention backend class.

        An explicitly selected backend other than ``TORCH_SDPA`` is not
        honored; only an info message is logged. MLA models get the CPU MLA
        backend. Otherwise the V1 or legacy Torch SDPA backend is chosen
        according to ``use_v1``.
        """
        if selected_backend and selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        if use_mla:
            logger.info("Using CPU MLA backend.")
            return "vllm.attention.backends.cpu_mla.CPUMLABackend"
        logger.info("Using Torch SDPA backend.")
        if use_v1:
            return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
        else:
            return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Total physical host memory in bytes; ``device_id`` is ignored."""
        return psutil.virtual_memory().total

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is never enabled on CPU."""
        return False

    @classmethod
    def inference_mode(cls):
        """Return the context manager used to run inference (``no_grad``)."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` for the CPU backend and patch it in place.

        Besides adjusting the model, cache, scheduler, parallel and
        compilation sub-configs, this writes several process environment
        variables: ``VLLM_WORKER_MULTIPROC_METHOD``, ``NUMEXPR_MAX_THREADS``,
        ``OMP_NUM_THREADS``, ``TORCHINDUCTOR_COMPILE_THREADS`` and
        ``LOCAL_WORLD_SIZE``. When ``libiomp5.so`` is in ``LD_PRELOAD``, it
        also sets the ``KMP_*`` tuning variables.

        Raises:
            RuntimeError: if a block size other than 16 is requested without
                intel_extension_for_pytorch, if chunked prefill or prefix
                caching is combined with an FP8 KV cache, or if
                ``VLLM_CPU_KVCACHE_SPACE`` is negative.
        """
        import vllm.envs as envs
        from vllm.utils import GiB_bytes
        model_config = vllm_config.model_config

        model_config.disable_cascade_attn = True

        # ---- KV cache / block size -------------------------------------
        cache_config = vllm_config.cache_config

        ipex_available = find_spec("intel_extension_for_pytorch") is not None

        # Default block size depends on whether IPEX kernels are available.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 128 if ipex_available else 16

        if not ipex_available and cache_config.block_size != 16:
            raise RuntimeError(
                f"--block-size={cache_config.block_size} requires"
                " intel_extension_for_pytorch")

        scheduler_config = vllm_config.scheduler_config
        if ((scheduler_config.chunked_prefill_enabled
             or cache_config.enable_prefix_caching)
                and cache_config.cache_dtype != "auto"):
            raise RuntimeError("Chunked-prefill and prefix-cache on the CPU "
                               "backend is not compatible with FP8 KV cache.")

        if cache_config.cache_dtype == "fp8_e4m3":
            cache_config.cache_dtype = "fp8_e5m2"
            logger.warning(
                "CPU backend doesn't support fp8_e4m3 KV cache type, "
                "cast to fp8_e5m2.")

        if (cache_config.cache_dtype != "auto"
                and model_config.dtype == torch.half):
            logger.warning("FP8 KV cache on the CPU backend does not"
                           " support fp16 for now, cast to bf16.")
            model_config.dtype = torch.bfloat16

        # VLLM_CPU_KVCACHE_SPACE is in GiB; 0 means "use the 4 GiB default".
        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
        if kv_cache_space < 0:
            raise RuntimeError(
                "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
                f" {kv_cache_space}, expect a non-negative integer value.")
        if kv_cache_space == 0:
            cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
            logger.warning(
                "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) "
                "for CPU backend is not set, using 4 by default.")
        else:
            cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes  # type: ignore # noqa

        # ---- Distributed executor / worker class -----------------------
        parallel_config = vllm_config.parallel_config
        if (parallel_config.world_size > 1
                and parallel_config.distributed_executor_backend is not None
                and parallel_config.distributed_executor_backend != "mp"):
            logger.warning(("%s is not supported on CPU, fallback to mp "
                            "distributed executor backend."),
                           parallel_config.distributed_executor_backend)
            parallel_config.distributed_executor_backend = "mp"
        if parallel_config.worker_cls == "auto":
            if vllm_config.speculative_config:
                parallel_config.worker_cls = \
                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                parallel_config.sd_worker_cls = \
                    "vllm.worker.cpu_worker.CPUWorker"
            elif envs.VLLM_USE_V1:
                parallel_config.worker_cls = \
                    "vllm.v1.worker.cpu_worker.CPUWorker"
            else:
                parallel_config.worker_cls = \
                    "vllm.worker.cpu_worker.CPUWorker"

        # ---- Compilation -----------------------------------------------
        # Note: workaround for v1 gpu_model_runner
        from vllm.config import CompilationLevel
        compilation_config = vllm_config.compilation_config
        # No CUDA graphs on CPU: capture nothing.
        compilation_config.cudagraph_capture_sizes = []

        if (envs.VLLM_USE_V1
                and compilation_config.level == CompilationLevel.PIECEWISE):

            # Note: vLLM V1 is using PIECEWISE level compilation, which will
            # take time to compile kernels just-in-time with the inductor
            # backend. For CPU CI tests, most of them are executed fast and
            # compilations consume too much time, even with torch compile
            # cache. So use VLLM_CPU_CI_ENV to indicate the CI environment,
            # and just execute model with dynamo + eager mode to save time.
            # VLLM_CPU_CI_ENV is only used as an internal variable.
            if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
                backend = "eager"
            else:
                backend = "inductor"

            compilation_config.level = CompilationLevel.DYNAMO_ONCE
            compilation_config.backend = backend
            compilation_config.inductor_compile_config.update({
                "dce": True,
                "size_asserts": False,
                "nan_asserts": False,
                "memory_planning": True,
                "epilogue_fusion": True,
            })

            if compilation_config.use_inductor:
                compilation_config.custom_ops = ["none"]

        # LoRA disables compilation entirely (overrides the block above).
        if vllm_config.lora_config is not None:
            compilation_config.level = CompilationLevel.NO_COMPILATION

        assert vllm_config.device_config.device_type == "cpu"

        #
        # Environment variables for CPU executor
        #

        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        # Note: to avoid the error 'nthreads cannot be larger than environment
        #  variable "NUMEXPR_MAX_THREADS" (64)'.
        os.environ["NUMEXPR_MAX_THREADS"] = str(get_max_threads())

        # Set default threads num for OpenMP parallel
        os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Intel OpenMP setting
        ld_preload_str = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_preload_str:
            # The time(milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ['KMP_BLOCKTIME'] = "1"
            # Prevents the CPU to run into low performance state
            os.environ['KMP_TPAUSE'] = "0"
            # Provides fine granularity parallelism
            os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"

        # To hint IPEX uses shared memory based AllReduce
        os.environ["LOCAL_WORLD_SIZE"] = str(
            vllm_config.parallel_config.tensor_parallel_size)

        # ---- MLA: no chunked prefill on this platform ------------------
        if vllm_config.model_config and vllm_config.model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.chunked_prefill_enabled = False
            # Without chunking, a whole prompt must fit in one batch.
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Pinned memory is unavailable on CPU; logs a warning on each call."""
        logger.warning("Pin memory is not supported on CPU.")
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Dotted path of the LoRA Punica wrapper for CPU."""
        return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa

    @classmethod
    def supports_structured_output(cls) -> bool:
        """Structured (guided) output is supported on CPU."""
        return True

    @classmethod
    def supports_v1(cls, model_config) -> bool:
        """Returns whether the current platform can support v1 for the supplied
        model configuration.
        """
        return True

    @classmethod
    def default_v1(cls, model_config) -> bool:
        """Returns whether the current platform can use v1 by default for the
        supplied model configuration.
        """
        return cls.supports_v1(
            model_config) and cls.get_cpu_architecture() == CpuArchEnum.X86