cpu.py 14.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import glob
5
import json
6
import os
7
import platform
8
import subprocess
9
import sys
10
from dataclasses import dataclass
11
from typing import TYPE_CHECKING
12

13
import regex as re
14
15
import torch

16
from vllm import envs
17
from vllm.attention.backends.registry import AttentionBackendEnum
18
19
from vllm.logger import init_logger

20
from .interface import CpuArchEnum, Platform, PlatformEnum
21
22

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    # VllmConfig is only needed for annotations in this module. At runtime the
    # name is bound to a placeholder so annotations such as
    # ``vllm_config: VllmConfig`` still resolve when the methods are defined.
    VllmConfig = None

logger = init_logger(__name__)

29

30
def get_max_threads(pid=0):
    """Return how many CPUs the process ``pid`` may run on.

    Where ``os.sched_getaffinity`` exists (e.g. Linux) this is the size of the
    set of CPUs the process is restricted to. On macOS, which has no affinity
    API, the total logical CPU count is used and ``pid`` is ignored.

    Args:
        pid: Process id to query; ``0`` means the calling process.

    Returns:
        The CPU count as an ``int``.

    Raises:
        NotImplementedError: On any other operating system.
    """
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(pid))
    if platform.system() == "Darwin":
        # os.cpu_count() returns None when the count cannot be determined;
        # fall back to 1 so callers (e.g. NUMEXPR_MAX_THREADS) get a number
        # instead of the string "None".
        return os.cpu_count() or 1
    raise NotImplementedError("Unsupported OS")


39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
@dataclass
class LogicalCPUInfo:
    id: int = -1
    physical_core: int = -1
    numa_node: int = -1

    @classmethod
    def _int(cls, value: str) -> int:
        try:
            int_value = int(value)
        except Exception:
            int_value = -1
        return int_value

    @staticmethod
    def json_decoder(obj_dict: dict):
        id = obj_dict.get("cpu")
        physical_core = obj_dict.get("core")
        numa_node = obj_dict.get("node")

        if not (id is None or physical_core is None or numa_node is None):
            return LogicalCPUInfo(
                id=LogicalCPUInfo._int(id),
                physical_core=LogicalCPUInfo._int(physical_core),
63
64
                numa_node=LogicalCPUInfo._int(numa_node),
            )
65
66
67
68
        else:
            return obj_dict


69
70
class CpuPlatform(Platform):
    _enum = PlatformEnum.CPU

    # Identifiers under which this platform presents itself.
    device_name: str = "cpu"
    device_type: str = "cpu"
    dispatch_key: str = "CPU"

    # Distributed communication backend name used on this platform.
    dist_backend: str = "gloo"

    # Environment variable listing the visible NUMA memory nodes; it is read by
    # get_allowed_cpu_core_node_list().
    device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"
76

77
    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Torch dtypes that model execution may use on this CPU.

        * PowerPC: ``bfloat16`` and ``float32``.
        * ARM on macOS: ``float16`` and ``float32``, plus ``bfloat16`` when the
          ``hw.optional.arm.FEAT_BF16`` sysctl reports ``1``.
        * RISC-V: ``float32`` only (temporary workaround, see below).
        * Anything else (x86/aarch64): ``bfloat16``, ``float16`` and
          ``float32``.
        """
        cpu_arch = self.get_cpu_architecture()

        if cpu_arch == CpuArchEnum.POWERPC:
            return [torch.bfloat16, torch.float32]

        if cpu_arch == CpuArchEnum.ARM and sys.platform.startswith("darwin"):
            # Run sysctl directly (no shell). If the key is unknown (presumably
            # on older macOS releases) sysctl exits non-zero; treat that, or a
            # missing sysctl binary, as "no native BF16" instead of raising.
            try:
                has_bf16 = (
                    subprocess.check_output(
                        ["sysctl", "-n", "hw.optional.arm.FEAT_BF16"],
                        stderr=subprocess.DEVNULL,
                    ).strip()
                    == b"1"
                )
            except (OSError, subprocess.CalledProcessError):
                has_bf16 = False
            if has_bf16:
                return [torch.bfloat16, torch.float16, torch.float32]
            return [torch.float16, torch.float32]

        if cpu_arch == CpuArchEnum.RISCV:
            # Workaround for Issue #25655: RISC-V scheduler bug with float16.
            #
            # Background:
            # - RISC-V currently uses the scalar code path.
            # - A latent bug in the vLLM scheduler can produce invalid
            #   physical_block_idx values (out of bounds in block_tables).
            # - With float16 this causes segmentation faults on RISC-V (tested
            #   on Sophgo SG2044). It does not reproduce on x86 or other
            #   architectures. Forcing float32 bypasses the issue.
            # - The root cause is in Python-level scheduling logic, not in the
            #   C++ kernels.
            #
            # This is a temporary workaround until the scheduler bug is fixed.
            # See: https://github.com/vllm-project/vllm/issues/25655
            return [torch.float32]

        # x86/aarch64 CPU has supported both bf16 and fp16 natively.
        return [torch.bfloat16, torch.float16, torch.float32]

120
121
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name. Every ``device_id`` maps to ``"cpu"``."""
        name = "cpu"
        return name

124
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: str | None,
        block_size: int,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        attn_type: str | None = None,
    ) -> str:
        """Return the import path of the attention backend used on CPU.

        Only ``CPU_ATTN`` is available. Any other explicitly selected backend
        is reported in the log and then ignored. The remaining arguments do
        not influence the result.

        Raises:
            NotImplementedError: if ``use_mla`` or ``use_sparse`` is set.
        """
        cpu_backend = AttentionBackendEnum.CPU_ATTN
        if selected_backend and selected_backend != cpu_backend:
            logger.info("Cannot use %s backend on CPU.", selected_backend)

        # Checked in this order: MLA first, then sparse attention.
        unsupported_features = (
            (use_mla, "MLA is not supported on CPU."),
            (use_sparse, "Sparse Attention is not supported on CPU."),
        )
        for enabled, message in unsupported_features:
            if enabled:
                raise NotImplementedError(message)

        return cpu_backend.get_path()
144

145
146
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the KV-cache budget for the CPU backend, in bytes.

        The size comes from ``VLLM_CPU_KVCACHE_SPACE``, interpreted as GiB.
        When that variable is unset, 4 GiB is used and a one-time warning is
        logged. ``device_id`` is ignored.
        """
        from vllm.utils.mem_constants import GiB_bytes

        space_gib = envs.VLLM_CPU_KVCACHE_SPACE
        if space_gib is not None:
            return space_gib * GiB_bytes

        logger.warning_once(
            "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) "
            "for CPU backend is not set, using 4 by default."
        )
        return 4 * GiB_bytes  # type: ignore
160

161
162
163
164
165
166
167
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Select ``device`` as the current device through torch's CPU module."""
        cpu_module = torch.cpu
        cpu_module.set_device(device)

168
169
    @classmethod
    def inference_mode(cls):
        """Return the context manager used for inference on CPU.

        On this platform it is ``torch.no_grad()``.
        """
        no_grad_ctx = torch.no_grad()
        return no_grad_ctx
171
172
173
174
175

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adapt ``vllm_config`` in place for the CPU backend and export the
        environment variables the CPU executor relies on.

        Config changes:
          * cascade attention is disabled;
          * ``block_size`` defaults to 128 (multiples of 32 are preferred);
          * a non-``"auto"`` KV-cache dtype is rejected together with chunked
            prefill / prefix caching, and otherwise reset to ``"auto"``;
          * ``cpu_kvcache_space_bytes`` comes from ``get_device_total_memory()``;
          * multi-process runs fall back to the ``"mp"`` executor backend, the
            CPU worker class is selected and Dual-Batch Overlap is disabled;
          * CUDA-graph capture sizes are cleared and ``VLLM_COMPILE`` is
            replaced by ``DYNAMO_TRACE_ONCE`` (compilation is turned off with
            LoRA);
          * with MLA, chunked prefill is disabled.

        Raises:
            RuntimeError: if chunked prefill or prefix caching is combined with
                a non-``"auto"`` KV-cache dtype.
        """
        model_config = vllm_config.model_config
        if model_config is not None:
            model_config.disable_cascade_attn = True

        # ---- KV cache -------------------------------------------------------
        cache_config = vllm_config.cache_config

        if cache_config.block_size is None:
            cache_config.block_size = 128

        if cache_config.block_size % 32 != 0:
            logger.warning(
                "CPU backend prefers block_size is multiples of 32, "
                "otherwise the performance is not optimized."
            )

        scheduler_config = vllm_config.scheduler_config
        if (
            scheduler_config.enable_chunked_prefill
            or cache_config.enable_prefix_caching
        ) and cache_config.cache_dtype != "auto":
            raise RuntimeError(
                "Chunked-prefill and prefix-cache on the CPU "
                "backend is not compatible with FP8 KV cache."
            )

        # Only reachable with both chunked prefill and prefix caching off.
        if cache_config.cache_dtype != "auto":
            logger.warning(
                "CPU backend doesn't support KV cache quantization fallback to auto."
            )
            cache_config.cache_dtype = "auto"

        cache_config.cpu_kvcache_space_bytes = cls.get_device_total_memory()

        # ---- Parallelism / workers -----------------------------------------
        parallel_config = vllm_config.parallel_config
        if (
            parallel_config.world_size > 1
            and parallel_config.distributed_executor_backend is not None
            and parallel_config.distributed_executor_backend != "mp"
        ):
            logger.warning(
                (
                    "%s is not supported on CPU, fallback to mp "
                    "distributed executor backend."
                ),
                parallel_config.distributed_executor_backend,
            )
            parallel_config.distributed_executor_backend = "mp"
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker"

        # Disable DBO
        if parallel_config.enable_dbo:
            logger.warning("Dual-Batch Overlap is not supported on CPU, disabled.")
            parallel_config.enable_dbo = False

        # ---- Compilation ---------------------------------------------------
        from vllm.config import CompilationMode

        compilation_config = vllm_config.compilation_config
        # Note: workaround for v1 gpu_model_runner (no CUDA-graph capture sizes).
        compilation_config.cudagraph_capture_sizes = []

        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
            # Note: vLLM V1 is using PIECEWISE level compilation, which will
            # take time to compile kernels just-in-time with the inductor
            # backend. For CPU CI tests, most of them are executed fast and
            # compilations consume too much time, even with torch compile
            # cache. So use VLLM_CPU_CI_ENV to indicate the CI environment,
            # and just execute model with dynamo + eager mode to save time.
            # VLLM_CPU_CI_ENV is only used as an internal variable.
            if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
                backend = "eager"
            else:
                backend = "inductor"

            compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
            compilation_config.backend = backend
            compilation_config.inductor_compile_config.update(
                {
                    "dce": True,
                    "size_asserts": False,
                    "nan_asserts": False,
                    "epilogue_fusion": True,
                }
            )

        if vllm_config.lora_config is not None:
            compilation_config.mode = CompilationMode.NONE

        assert vllm_config.device_config.device_type == "cpu"

        # ---- Environment variables for CPU executor ------------------------
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        # Note: to avoid the error 'nthreads cannot be larger than environment
        # variable "NUMEXPR_MAX_THREADS" (64)'.
        os.environ["NUMEXPR_MAX_THREADS"] = str(get_max_threads())

        if envs.VLLM_CPU_OMP_THREADS_BIND != "nobind":
            # Set default threads num for OpenMP parallel
            os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())
        else:
            # In this case, setting the OpenMP configuration via
            # OMP_NUM_THREADS is up to the user.
            logger.info("Disabling binding processes to CPU cores...")

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Disable multi-stream for shared experts as no Stream on CPU
        os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"

        # Intel OpenMP setting
        ld_preload_str = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_preload_str:
            # The time(milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ["KMP_BLOCKTIME"] = "1"
            # Prevents the CPU to run into low performance state
            os.environ["KMP_TPAUSE"] = "0"
            # Provides fine granularity parallelism
            os.environ["KMP_FORKJOIN_BARRIER_PATTERN"] = "dist,dist"
            os.environ["KMP_PLAIN_BARRIER_PATTERN"] = "dist,dist"
            os.environ["KMP_REDUCTION_BARRIER_PATTERN"] = "dist,dist"

        if (
            platform.system() == "Linux"
            and cls.get_cpu_architecture() in (CpuArchEnum.ARM, CpuArchEnum.POWERPC)
            and not ("libomp" in ld_preload_str or "libgomp" in ld_preload_str)
        ):
            # We need to LD_PRELOAD PyTorch's libgomp, otherwise only
            # one core will be properly utilized when we thread-bind
            # See: https://github.com/vllm-project/vllm/issues/27369
            # TODO: Remove once:
            # https://github.com/pytorch/pytorch/issues/166087 is fixed

            # We need to find the location of PyTorch's libgomp
            torch_pkg = os.path.dirname(torch.__file__)
            site_root = os.path.dirname(torch_pkg)
            torch_libs = os.path.join(site_root, "torch.libs")
            # glob.glob returns matches in arbitrary order; sort so the chosen
            # library is deterministic if several copies are present.
            pytorch_libgomp_so_candidates = sorted(
                glob.glob(os.path.join(torch_libs, "libgomp-*.so*"))
            )
            if pytorch_libgomp_so_candidates:
                pytorch_libgomp_so = pytorch_libgomp_so_candidates[0]
                if ld_preload_str:
                    ld_preload_str += ":"
                ld_preload_str += pytorch_libgomp_so
                # LD_PRELOAD is consulted by the dynamic loader at process
                # start, so this presumably targets the spawned workers.
                os.environ["LD_PRELOAD"] = ld_preload_str

        # To hint IPEX uses shared memory based AllReduce
        os.environ["LOCAL_WORLD_SIZE"] = str(parallel_config.tensor_parallel_size)

        if model_config is not None and model_config.use_mla:
            # NOTE(review): the message also mentions prefix caching, but only
            # chunked prefill is turned off here -- confirm which is intended.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.max_num_batched_tokens = max(
                model_config.max_model_len,
                scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )
342

343
    @classmethod
    def get_allowed_cpu_core_node_list(cls) -> tuple[list[int], list[LogicalCPUInfo]]:
        """List the NUMA nodes and logical CPUs this process may use (Linux).

        Logical CPUs are read from ``lscpu -J -e=CPU,CORE,NODE``. Entries with
        an unparsable CPU/core/node id are dropped, as are CPUs outside
        ``os.sched_getaffinity(0)``.

        By default the node list is every NUMA node owning at least one allowed
        CPU, sorted. If ``device_control_env_var`` (``CPU_VISIBLE_MEMORY_NODES``)
        holds a non-empty comma-separated list of node ids, the result is that
        list instead. It keeps the given order, drops duplicates and empty
        entries, and is restricted to nodes that own an allowed CPU.

        Returns:
            ``(allowed_numa_nodes, logical_cpus)``.

        Raises:
            NotImplementedError: if ``os.sched_getaffinity`` is unavailable.
            ValueError: if the visible-nodes variable holds a non-integer entry.
        """
        assert platform.system() == "Linux"

        # Init LogicalCPUInfo from lscpu (argv list, no shell needed).
        lscpu_output = subprocess.check_output(
            ["lscpu", "-J", "-e=CPU,CORE,NODE"], text=True
        )
        # Presumably lscpu emits a bare "-" for an unknown node, which is not
        # valid JSON; rewrite it as node 0.
        lscpu_output = re.sub(r'"node":\s*-\s*(,|\n)', r'"node": 0\1', lscpu_output)
        logical_cpu_list: list[LogicalCPUInfo] = json.loads(
            lscpu_output, object_hook=LogicalCPUInfo.json_decoder
        )["cpus"]

        # Filter CPUs with invalid attributes
        logical_cpu_list = [
            x
            for x in logical_cpu_list
            if -1 not in (x.id, x.physical_core, x.numa_node)
        ]

        # Filter allowed CPUs
        if not hasattr(os, "sched_getaffinity"):
            raise NotImplementedError("Unsupported OS")
        allowed_cpu_id_list = os.sched_getaffinity(0)
        logical_cpu_list = [x for x in logical_cpu_list if x.id in allowed_cpu_id_list]

        # Get allowed NUMA nodes
        allowed_numa_nodes: set[int] = {x.numa_node for x in logical_cpu_list}
        allowed_numa_nodes_list = sorted(allowed_numa_nodes)

        env_value = os.environ.get(cls.device_control_env_var, "")
        if env_value != "":
            visible_nodes = [int(s) for s in env_value.split(",") if s.strip()]
            # Compare requested NUMA node ids with the allowed NUMA node ids
            # (not with CPU ids). dict.fromkeys de-duplicates in order.
            allowed_numa_nodes_list = [
                node
                for node in dict.fromkeys(visible_nodes)
                if node in allowed_numa_nodes
            ]

        return allowed_numa_nodes_list, logical_cpu_list

385
386
387
    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Report that pinned memory is not available on the CPU platform."""
        pin_memory_supported = False
        return pin_memory_supported
388
389
390
391

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the Punica (LoRA) wrapper used on CPU."""
        wrapper_path = "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
        return wrapper_path
392
393
394
395
396
397
398

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the communicator class used for
        distributed communication on the CPU platform."""
        communicator_path = (
            "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"
        )
        return communicator_path
399
400
401
402
403

    @classmethod
    def supports_structured_output(cls) -> bool:
        """Report that structured output is supported on the CPU platform."""
        supported = True
        return supported

404
405
406
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Opt the CPU platform into opaque attention ops.

        Presumably this means attention is treated as a single opaque op during
        compilation; confirm against the ``Platform`` base class.
        """
        uses_opaque_op = True
        return uses_opaque_op
407
408
409
410

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Report that the CPU platform supports a hybrid KV cache (see the
        ``Platform`` base class for what this flag enables)."""
        hybrid_supported = True
        return hybrid_supported