cpu.py 12.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import json
5
import os
6
import platform
7
import subprocess
8
import sys
9
from dataclasses import dataclass
10
from importlib.util import find_spec
11
from typing import TYPE_CHECKING, Optional
12

13
14
import torch

15
from vllm.logger import init_logger
16
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
17

18
from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
19
20

logger = init_logger(__name__)
21

22
23
24
25
26
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None

27

28
29
30
31
32
33
34
35
36
def get_max_threads(pid=0):
    """Return how many CPUs are usable for process ``pid``.

    Where ``os.sched_getaffinity`` exists, the size of the affinity set of
    ``pid`` is returned. On macOS (detected via ``platform.system()``) the
    value of ``os.cpu_count()`` is returned instead and ``pid`` is ignored.

    Raises:
        NotImplementedError: On any other operating system.
    """
    affinity_getter = getattr(os, 'sched_getaffinity', None)
    if affinity_getter is not None:
        return len(affinity_getter(pid))
    if platform.system() != 'Darwin':
        raise NotImplementedError("Unsupported OS")
    return os.cpu_count()


37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
@dataclass
class LogicalCPUInfo:
    """Identifiers of a single logical CPU; ``-1`` marks an unknown value."""
    # Logical CPU id (filled from the "cpu" key by ``json_decoder``).
    id: int = -1
    # Physical core id (filled from the "core" key by ``json_decoder``).
    physical_core: int = -1
    # NUMA node id (filled from the "node" key by ``json_decoder``).
    numa_node: int = -1

    @classmethod
    def _int(cls, value: str) -> int:
        """Convert ``value`` to ``int``; yield ``-1`` if conversion fails."""
        try:
            return int(value)
        except Exception:
            return -1

    @staticmethod
    def json_decoder(obj_dict: dict):
        """``object_hook`` for ``json.loads``.

        A dict whose "cpu", "core" and "node" entries are all present and
        not ``None`` becomes a ``LogicalCPUInfo``; any other dict is handed
        back unchanged.
        """
        cpu_id = obj_dict.get("cpu")
        core = obj_dict.get("core")
        node = obj_dict.get("node")

        if cpu_id is None or core is None or node is None:
            return obj_dict

        to_int = LogicalCPUInfo._int
        return LogicalCPUInfo(id=to_int(cpu_id),
                              physical_core=to_int(core),
                              numa_node=to_int(node))


66
67
class CpuPlatform(Platform):
    _enum = PlatformEnum.CPU
68
    device_name: str = "cpu"
69
    device_type: str = "cpu"
70
    dispatch_key: str = "CPU"
71
    dist_backend: str = "gloo"
72

73
    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """List the torch dtypes offered on the current CPU.

        Returns:
            ``[bfloat16, float32]`` when the architecture is POWERPC,
            ``[float16, float32]`` on a darwin platform with an ARM
            architecture, and ``[bfloat16, float16, float32]`` otherwise.
        """
        arch = self.get_cpu_architecture()
        if arch == CpuArchEnum.POWERPC:
            return [torch.bfloat16, torch.float32]
        if sys.platform.startswith("darwin") and arch == CpuArchEnum.ARM:
            # TODO: change this condition to check if the platform support bf16
            # instead of checking the OS. For instance M2 shall supports bf16
            # already. But we need to modify `cpu_extension.cmake` to activate
            # the feature in the build.
            return [torch.float16, torch.float32]
        # x86/aarch64 CPU has supported both bf16 and fp16 natively.
        return [torch.bfloat16, torch.float16, torch.float32]

87
88
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
89
90
        return "cpu"

91
    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Return the import path of the attention backend class for CPU.

        A requested backend other than ``_Backend.TORCH_SDPA`` is only
        logged; the Torch SDPA backend path is returned regardless.
        ``head_size``, ``dtype``, ``kv_cache_dtype`` and ``block_size`` are
        not consulted.

        Raises:
            NotImplementedError: If ``use_mla`` is true.
            ValueError: If ``use_v1`` is false.
        """
        if selected_backend and selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        if use_mla:
            raise NotImplementedError("MLA is not supported on CPU.")
        logger.info("Using Torch SDPA backend.")
        if not use_v1:
            raise ValueError("CPU backend only supports V1.")
        return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
104

105
106
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the ``total`` field of ``psutil.virtual_memory()``.

        ``device_id`` is accepted but not used. ``psutil`` is imported
        inside the method.
        """
        import psutil

        return psutil.virtual_memory().total

110
111
112
113
114
115
116
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cpu.set_device(device)

117
118
119
120
    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return False

121
122
    @classmethod
    def inference_mode(cls):
123
        return torch.no_grad()
124
125
126
127
128
129
130

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` for the CPU platform and adjust it in place.

        In order, this method:

        * disables cascade attention on the model config (when present);
        * chooses a default KV-cache block size and rejects a block size
          other than 16 when ``intel_extension_for_pytorch`` is not found;
        * rejects a non-"auto" KV-cache dtype combined with chunked prefill
          or prefix caching, and rewrites ``fp8_e4m3`` to ``fp8_e5m2``;
        * switches an fp16 model dtype to bf16 when the KV-cache dtype is
          not "auto";
        * derives ``cpu_kvcache_space_bytes`` from
          ``envs.VLLM_CPU_KVCACHE_SPACE``;
        * forces the "mp" distributed executor backend and picks the default
          worker class;
        * adjusts the compilation config;
        * sets several process environment variables;
        * disables chunked prefill when the model uses MLA.

        Raises:
            RuntimeError: For an unsupported block size, an incompatible
                KV-cache dtype, or a negative ``VLLM_CPU_KVCACHE_SPACE``.
            AssertionError: If the device type of ``vllm_config`` is not
                "cpu".
        """
        import vllm.envs as envs
        from vllm.utils import GiB_bytes

        model_config = vllm_config.model_config
        if model_config is not None:
            model_config.disable_cascade_attn = True

        cache_config = vllm_config.cache_config
        ipex_available = find_spec("intel_extension_for_pytorch") is not None

        # NOTE(review): cache_config is only checked for truthiness here;
        # every later access in this method assumes it is set.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 128 if ipex_available else 16

        if not ipex_available and cache_config.block_size != 16:
            raise RuntimeError(
                f"--block-size={cache_config.block_size} requires"
                " intel_extension_for_pytorch")

        scheduler_config = vllm_config.scheduler_config
        if ((scheduler_config.chunked_prefill_enabled
             or cache_config.enable_prefix_caching)
                and cache_config.cache_dtype != "auto"):
            raise RuntimeError("Chunked-prefill and prefix-cache on the CPU "
                               "backend is not compatible with FP8 KV cache.")

        if cache_config.cache_dtype == "fp8_e4m3":
            cache_config.cache_dtype = "fp8_e5m2"
            logger.warning(
                "CPU backend doesn't support fp8_e4m3 KV cache type, "
                "cast to fp8_e5m2.")

        if (cache_config.cache_dtype != "auto" and model_config is not None
                and model_config.dtype == torch.half):
            logger.warning("FP8 KV cache on the CPU backend only does not"
                           " support fp16 for now, cast to bf16.")
            model_config.dtype = torch.bfloat16

        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE

        if kv_cache_space >= 0:
            if kv_cache_space == 0:
                # A value of 0 is handled as "not set": fall back to 4 GiB.
                cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
                logger.warning(
                    "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) "
                    "for CPU backend is not set, using 4 by default.")
            else:
                cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes  # type: ignore # noqa
        else:
            raise RuntimeError(
                "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
                f" {kv_cache_space}, expect a positive integer value.")

        parallel_config = vllm_config.parallel_config
        if (parallel_config.world_size > 1
                and parallel_config.distributed_executor_backend is not None
                and parallel_config.distributed_executor_backend != "mp"):
            logger.warning(("%s is not supported on CPU, fallback to mp "
                            "distributed executor backend."),
                           parallel_config.distributed_executor_backend)
            parallel_config.distributed_executor_backend = "mp"
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker"

        # Note: workaround for v1 gpu_model_runner
        from vllm.config import CompilationLevel
        compilation_config = vllm_config.compilation_config
        compilation_config.cudagraph_capture_sizes = []

        if compilation_config.level == CompilationLevel.PIECEWISE:

            # Note: vLLM V1 is using PIECEWISE level compilation, which will
            # take time to compile kernels just-in-time with the inductor
            # backend. For CPU CI tests, most of them are executed fast and
            # compilations consume too much time, even with torch compile
            # cache. So use VLLM_CPU_CI_ENV to indicate the CI environment,
            # and just execute model with dynamo + eager mode to save time.
            # VLLM_CPU_CI_ENV is only used as an internal variable.
            if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
                backend = "eager"
            else:
                backend = "inductor"

            compilation_config.level = CompilationLevel.DYNAMO_ONCE
            compilation_config.backend = backend
            compilation_config.inductor_compile_config.update({
                "dce": True,
                "size_asserts": False,
                "nan_asserts": False,
                "memory_planning": True,
                "epilogue_fusion": True,
            })
            if compilation_config.use_inductor:
                compilation_config.custom_ops = ["none"]

        # A LoRA config overrides whatever level was chosen above.
        if vllm_config.lora_config is not None:
            compilation_config.level = CompilationLevel.NO_COMPILATION

        assert vllm_config.device_config.device_type == "cpu"

        #
        # Environment variables for CPU executor
        #

        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        # Note: to avoid the error 'nthreads cannot be larger than environment
        #  variable "NUMEXPR_MAX_THREADS" (64)'.
        os.environ["NUMEXPR_MAX_THREADS"] = str(get_max_threads())

        # Set default threads num for OpenMP parallel
        os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Intel OpenMP setting
        ld_preload_str = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_preload_str:
            # The time(milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ['KMP_BLOCKTIME'] = "1"
            # Prevents the CPU to run into low performance state
            os.environ['KMP_TPAUSE'] = "0"
            # Provides fine granularity parallelism
            os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"

        # To hint IPEX uses shared memory based AllReduce
        os.environ["LOCAL_WORLD_SIZE"] = str(
            vllm_config.parallel_config.tensor_parallel_size)

        if model_config is not None and model_config.use_mla:
            # NOTE(review): the log message mentions prefix caching, but only
            # chunked-prefill settings are changed in this block.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.chunked_prefill_enabled = False
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
    @classmethod
    def get_allowed_cpu_memory_node_list(
            cls) -> tuple[list[int], list[LogicalCPUInfo]]:
        """Collect the logical CPUs this process may run on and their nodes.

        Runs ``lscpu -J -e=CPU,CORE,NODE``, decodes its JSON output with
        ``LogicalCPUInfo.json_decoder`` and keeps only the entries that
        decoded into a ``LogicalCPUInfo``, have no ``-1`` field, and whose
        id is in ``os.sched_getaffinity(0)``.

        Returns:
            A tuple ``(numa_nodes, logical_cpus)``: the sorted distinct
            ``numa_node`` values of the kept CPUs, and the kept CPUs in the
            order lscpu reported them.

        Raises:
            AssertionError: If ``platform.system()`` is not "Linux".
            subprocess.CalledProcessError: If lscpu exits with a non-zero
                status.
        """
        assert platform.system() == "Linux"

        # Init LogicalCPUInfo from lscpu. The command is fixed and needs no
        # shell features, so it is run directly from an argument list.
        lscpu_output = subprocess.check_output(
            ["lscpu", "-J", "-e=CPU,CORE,NODE"], text=True)
        parsed_cpus = json.loads(
            lscpu_output, object_hook=LogicalCPUInfo.json_decoder)['cpus']

        allowed_cpu_ids = os.sched_getaffinity(0)

        # Keep valid, allowed CPUs. json_decoder leaves an entry as a plain
        # dict when any of its fields is null; such entries carry no usable
        # attributes and are dropped here instead of raising AttributeError.
        logical_cpu_list: list[LogicalCPUInfo] = [
            x for x in parsed_cpus
            if isinstance(x, LogicalCPUInfo)
            and -1 not in (x.id, x.physical_core, x.numa_node)
            and x.id in allowed_cpu_ids
        ]

        # Get allowed NUMA nodes
        allowed_numa_nodes_list = sorted(
            {x.numa_node for x in logical_cpu_list})

        return allowed_numa_nodes_list, logical_cpu_list

307
308
309
310
    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Log a warning that pin memory is unsupported and return False."""
        pin_memory_supported = False
        logger.warning("Pin memory is not supported on CPU.")
        return pin_memory_supported
311
312
313
314

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
315
316
317
318
319
320
321

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa
322
323
324
325
326
327
328
329
330
331
332

    @classmethod
    def supports_structured_output(cls) -> bool:
        return True

    @classmethod
    def supports_v1(cls, model_config) -> bool:
        """Returns whether the current platform can support v1 for the supplied
        model configuration.
        """
        return True
333
334
335
336
337
338

    @classmethod
    def default_v1(cls, model_config) -> bool:
        """Returns whether the current platform can use v1 by default for the
        supplied model configuration.

        This holds when ``supports_v1(model_config)`` is true and the CPU
        architecture reported by ``get_cpu_architecture()`` is X86, POWERPC
        or ARM.
        """
        arch = cls.get_cpu_architecture()
        v1_default_archs = (CpuArchEnum.X86, CpuArchEnum.POWERPC,
                            CpuArchEnum.ARM)
        return cls.supports_v1(model_config) and arch in v1_default_archs