"vscode:/vscode.git/clone" did not exist on "3123f151387d2afa49eaf3130bcee3556f2e87d2"
cpu.py 13.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import json
5
import os
6
import platform
7
import subprocess
8
import sys
9
from dataclasses import dataclass
10
from importlib.util import find_spec
11
from typing import TYPE_CHECKING, Optional
12

13
14
import torch

15
from vllm.logger import init_logger
16
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
17

18
from .interface import CpuArchEnum, Platform, PlatformEnum
19
20

logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.attention.backends.registry import _Backend
    from vllm.config import VllmConfig
else:
    # Runtime placeholders: the real classes are only imported for type
    # checking, but the names must still exist at runtime because some
    # annotations in this module (e.g. ``vllm_config: VllmConfig``) are
    # written unquoted and are evaluated when the function is defined.
    _Backend = None
    VllmConfig = None

29

30
31
32
33
34
35
36
37
38
def get_max_threads(pid=0):
    """Return how many CPUs process ``pid`` may run on (``0`` = this process).

    Uses the scheduler affinity mask when the OS exposes
    ``os.sched_getaffinity``. Otherwise, on macOS, falls back to the total
    logical CPU count (``pid`` is ignored in that case).

    Raises:
        NotImplementedError: on any other OS.
    """
    if hasattr(os, 'sched_getaffinity'):
        return len(os.sched_getaffinity(pid))
    if platform.system() == 'Darwin':
        # os.cpu_count() returns None when the count cannot be determined;
        # the result is stringified into env vars such as
        # NUMEXPR_MAX_THREADS, so always hand back a positive int.
        return os.cpu_count() or 1
    raise NotImplementedError("Unsupported OS")


39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
@dataclass
class LogicalCPUInfo:
    """One logical CPU row from ``lscpu -J -e=CPU,CORE,NODE``.

    Any field that is missing or cannot be parsed as an integer is stored
    as ``-1`` so callers can filter such rows out.
    """
    id: int = -1  # logical CPU id (lscpu "cpu" key)
    physical_core: int = -1  # physical core id (lscpu "core" key)
    numa_node: int = -1  # NUMA node id (lscpu "node" key)

    @classmethod
    def _int(cls, value: object) -> int:
        """Convert ``value`` to ``int``; return ``-1`` if it is not numeric."""
        try:
            return int(value)  # type: ignore[call-overload]
        except (TypeError, ValueError, OverflowError):
            return -1

    @staticmethod
    def json_decoder(obj_dict: dict):
        """``object_hook`` for ``json.loads`` applied to ``lscpu -J`` output.

        Objects that carry all of the ``cpu``, ``core`` and ``node`` keys are
        converted to a :class:`LogicalCPUInfo`. Any other object (such as the
        top-level ``{"cpus": [...]}`` wrapper) is returned unchanged.

        A ``null`` or non-numeric field value becomes ``-1`` instead of
        leaving the row as a plain dict. A plain dict would break callers
        that access ``.id`` / ``.numa_node`` on every row.
        """
        if not all(key in obj_dict for key in ("cpu", "core", "node")):
            return obj_dict
        return LogicalCPUInfo(
            id=LogicalCPUInfo._int(obj_dict["cpu"]),
            physical_core=LogicalCPUInfo._int(obj_dict["core"]),
            numa_node=LogicalCPUInfo._int(obj_dict["node"]))


68
69
class CpuPlatform(Platform):
    """vLLM platform implementation for running models on host CPUs."""
    _enum = PlatformEnum.CPU
    device_name: str = "cpu"
    device_type: str = "cpu"
    dispatch_key: str = "CPU"
    dist_backend: str = "gloo"
    # Comma-separated NUMA node ids; read by get_allowed_cpu_core_node_list.
    device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the torch dtypes supported on the detected CPU.

        The result depends on the CPU architecture. On macOS/ARM it also
        depends on whether the ``hw.optional.arm.FEAT_BF16`` sysctl reports
        ``1``.
        """
        arch = self.get_cpu_architecture()
        if arch == CpuArchEnum.POWERPC:
            return [torch.bfloat16, torch.float32]
        if arch == CpuArchEnum.ARM and sys.platform.startswith("darwin"):
            if self._darwin_arm_supports_bf16():
                return [torch.bfloat16, torch.float16, torch.float32]
            return [torch.float16, torch.float32]
        if arch == CpuArchEnum.RISCV:
            # Workaround for issue #25655. RISC-V currently uses the scalar
            # code path, and a latent scheduler bug can produce out-of-bounds
            # physical_block_idx values in block_tables. With float16 this
            # causes segmentation faults (observed on Sophgo SG2044; not
            # reproduced on x86 or other architectures). The root cause is
            # in Python-level scheduling logic, not the C++ kernels, and
            # forcing float32 bypasses it. This is temporary until the
            # scheduler bug is fixed; see
            # https://github.com/vllm-project/vllm/issues/25655
            return [torch.float32]
        # x86/aarch64 CPU has supported both bf16 and fp16 natively.
        return [torch.bfloat16, torch.float16, torch.float32]

    @staticmethod
    def _darwin_arm_supports_bf16() -> bool:
        """Return True if macOS reports the ARM FEAT_BF16 extension.

        If the query fails (no ``sysctl`` binary, or the key is unknown to
        this macOS version), treat bf16 as unsupported instead of raising.
        """
        try:
            output = subprocess.check_output(
                ["sysctl", "-n", "hw.optional.arm.FEAT_BF16"],
                stderr=subprocess.DEVNULL)
        except (OSError, subprocess.CalledProcessError):
            return False
        return output.strip() == b"1"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"cpu"``; ``device_id`` is ignored."""
        return "cpu"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink: bool, use_sparse: bool) -> str:
        """Return the import path of the attention backend for CPU.

        Only Torch SDPA under V1 is available. A different
        ``selected_backend`` is logged and ignored.

        Raises:
            NotImplementedError: if MLA or sparse attention is requested.
            ValueError: if ``use_v1`` is False.
        """
        from vllm.attention.backends.registry import _Backend
        if selected_backend and selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        if use_mla:
            raise NotImplementedError("MLA is not supported on CPU.")
        if use_sparse:
            raise NotImplementedError(
                "Sparse Attention is not supported on CPU.")
        if not use_v1:
            raise ValueError("CPU backend only supports V1.")
        # Log only once we know this backend is actually going to be used.
        logger.info("Using Torch SDPA backend.")
        return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the KV-cache memory budget in bytes.

        The budget comes from ``VLLM_CPU_KVCACHE_SPACE``, interpreted in GiB.
        When unset, it defaults to 4 GiB and a one-time warning is logged.
        Physical RAM is not queried, and ``device_id`` is ignored.
        """
        import vllm.envs as envs
        from vllm.utils import GiB_bytes

        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
        if kv_cache_space is None:
            logger.warning_once(
                "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) "
                "for CPU backend is not set, using 4 by default.")
            return 4 * GiB_bytes  # type: ignore
        return kv_cache_space * GiB_bytes

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cpu.set_device(device)

    @classmethod
    def inference_mode(cls):
        """Return the context manager used for inference (``torch.no_grad``)."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place for the CPU backend.

        This method also exports the environment variables used by the CPU
        worker processes and by the OpenMP / NumExpr / Inductor runtimes.

        Raises:
            RuntimeError: if a block size other than 16 is requested without
                intel_extension_for_pytorch, or if chunked prefill / prefix
                caching is combined with a non-"auto" KV cache dtype.
            AssertionError: if the device type is not ``"cpu"``.
        """
        model_config = vllm_config.model_config
        if model_config is not None:
            model_config.disable_cascade_attn = True

        # ---- KV cache ---------------------------------------------------
        # cache_config is used unconditionally below, so it must be set.
        cache_config = vllm_config.cache_config
        ipex_available = find_spec("intel_extension_for_pytorch") is not None
        if cache_config.block_size is None:
            cache_config.block_size = 128 if ipex_available else 16
        if not ipex_available and cache_config.block_size != 16:
            raise RuntimeError(
                f"--block-size={cache_config.block_size} requires"
                " intel_extension_for_pytorch")

        scheduler_config = vllm_config.scheduler_config
        if ((scheduler_config.chunked_prefill_enabled
             or cache_config.enable_prefix_caching)
                and cache_config.cache_dtype != "auto"):
            raise RuntimeError("Chunked-prefill and prefix-cache on the CPU "
                               "backend is not compatible with FP8 KV cache.")

        if cache_config.cache_dtype == "fp8_e4m3":
            cache_config.cache_dtype = "fp8_e5m2"
            logger.warning(
                "CPU backend doesn't support fp8_e4m3 KV cache type, "
                "cast to fp8_e5m2.")

        if (cache_config.cache_dtype != "auto" and model_config is not None
                and model_config.dtype == torch.half):
            logger.warning("FP8 KV cache on the CPU backend only does not"
                           " support fp16 for now, cast to bf16.")
            model_config.dtype = torch.bfloat16

        cache_config.cpu_kvcache_space_bytes = cls.get_device_total_memory()

        # ---- Parallelism ------------------------------------------------
        parallel_config = vllm_config.parallel_config
        executor_backend = parallel_config.distributed_executor_backend
        if (parallel_config.world_size > 1 and executor_backend is not None
                and executor_backend != "mp"):
            logger.warning(
                "%s is not supported on CPU, fallback to mp "
                "distributed executor backend.", executor_backend)
            parallel_config.distributed_executor_backend = "mp"
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker"
        # Disable DBO
        if parallel_config.enable_dbo:
            logger.warning(
                "Dual-Batch Overlap is not supported on CPU, disabled.")
            parallel_config.enable_dbo = False

        # ---- Compilation ------------------------------------------------
        # Note: workaround for v1 gpu_model_runner
        from vllm.config import CompilationLevel
        compilation_config = vllm_config.compilation_config
        compilation_config.cudagraph_capture_sizes = []

        if compilation_config.level == CompilationLevel.PIECEWISE:
            # Note: vLLM V1 is using PIECEWISE level compilation, which will
            # take time to compile kernels just-in-time with the inductor
            # backend. For CPU CI tests, most of them are executed fast and
            # compilations consume too much time, even with torch compile
            # cache. So use VLLM_CPU_CI_ENV to indicate the CI environment,
            # and just execute model with dynamo + eager mode to save time.
            # VLLM_CPU_CI_ENV is only used as an internal variable.
            in_ci = os.environ.get("VLLM_CPU_CI_ENV", "0") != "0"
            compilation_config.level = CompilationLevel.DYNAMO_ONCE
            compilation_config.backend = "eager" if in_ci else "inductor"
            compilation_config.inductor_compile_config.update({
                "dce": True,
                "size_asserts": False,
                "nan_asserts": False,
                "epilogue_fusion": True,
            })
            if compilation_config.use_inductor:
                compilation_config.custom_ops = ["none"]

        if vllm_config.lora_config is not None:
            compilation_config.level = CompilationLevel.NO_COMPILATION

        assert vllm_config.device_config.device_type == "cpu"

        # ---- Environment variables for CPU executor ---------------------
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        # Note: to avoid the error 'nthreads cannot be larger than environment
        #  variable "NUMEXPR_MAX_THREADS" (64)'.
        os.environ["NUMEXPR_MAX_THREADS"] = str(get_max_threads())

        # Set default threads num for OpenMP parallel
        os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Intel OpenMP settings, applied only when libiomp5 is preloaded.
        if "libiomp5.so" in os.getenv("LD_PRELOAD", ""):
            # The time(milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ['KMP_BLOCKTIME'] = "1"
            # Prevents the CPU to run into low performance state
            os.environ['KMP_TPAUSE'] = "0"
            # Provides fine granularity parallelism
            os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"

        # To hint IPEX uses shared memory based AllReduce
        os.environ["LOCAL_WORLD_SIZE"] = str(
            parallel_config.tensor_parallel_size)

        if model_config is not None and model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.chunked_prefill_enabled = False
            scheduler_config.max_num_batched_tokens = max(
                scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

    @classmethod
    def get_allowed_cpu_core_node_list(
            cls) -> tuple[list[int], list[LogicalCPUInfo]]:
        """Return the NUMA nodes and logical CPUs this process may use.

        This is Linux only. It queries ``lscpu`` for the CPU / core / NUMA
        topology, drops rows with missing or unparsable fields, and keeps only
        CPUs in this process's scheduler affinity mask.

        If ``CPU_VISIBLE_MEMORY_NODES`` is set to a comma-separated list of
        node ids, the node list becomes those ids (in the given order) that
        also host at least one allowed CPU.

        Returns:
            ``(numa_node_ids, logical_cpus)``. ``numa_node_ids`` is sorted
            unless overridden by the env var.
        """
        assert platform.system() == "Linux"

        lscpu_output = subprocess.check_output(
            ["lscpu", "-J", "-e=CPU,CORE,NODE"], text=True)
        parsed_cpus = json.loads(
            lscpu_output, object_hook=LogicalCPUInfo.json_decoder)['cpus']

        allowed_cpu_ids = os.sched_getaffinity(0)
        logical_cpu_list: list[LogicalCPUInfo] = [
            cpu for cpu in parsed_cpus
            # Rows the decoder could not convert stay plain dicts; skip them.
            if isinstance(cpu, LogicalCPUInfo)
            and -1 not in (cpu.id, cpu.physical_core, cpu.numa_node)
            and cpu.id in allowed_cpu_ids
        ]

        allowed_numa_nodes: set[int] = {
            cpu.numa_node
            for cpu in logical_cpu_list
        }
        allowed_numa_nodes_list = sorted(allowed_numa_nodes)

        visible_nodes_str = os.environ.get(cls.device_control_env_var, "")
        if visible_nodes_str:
            visible_nodes = [
                int(s) for s in visible_nodes_str.split(',') if s.strip()
            ]
            # Filter by NUMA node ids (not CPU ids): keep only requested
            # nodes that actually host an allowed CPU.
            allowed_numa_nodes_list = [
                node for node in visible_nodes if node in allowed_numa_nodes
            ]

        return allowed_numa_nodes_list, logical_cpu_list

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Pinned memory is not supported on CPU; always returns False."""
        logger.warning("Pin memory is not supported on CPU.")
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the CPU LoRA Punica wrapper."""
        return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa

    @classmethod
    def supports_structured_output(cls) -> bool:
        """Structured output is supported on CPU."""
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Return True (attention is treated as an opaque op on CPU)."""
        return True

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Hybrid KV cache is supported on CPU."""
        return True