# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import json
5
import os
6
import platform
7
import subprocess
8
import sys
9
from dataclasses import dataclass
10
from importlib.util import find_spec
11
from typing import TYPE_CHECKING, Optional
12

13
14
import torch

15
from vllm.logger import init_logger
16
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
17

18
from .interface import CpuArchEnum, Platform, PlatformEnum
19
20

logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.attention.backends.registry import _Backend
    from vllm.config import VllmConfig
else:
    # Runtime placeholders: annotations in this module (e.g. the unquoted
    # ``VllmConfig`` on ``check_and_update_config``) are evaluated at def
    # time, so these names must exist even when the real types are not
    # imported.
    _Backend = None
    VllmConfig = None

29

30
def get_max_threads(pid=0):
    """Return the number of logical CPUs usable by process ``pid``.

    Where ``os.sched_getaffinity`` is available, the process CPU affinity
    mask is honored. On macOS (Darwin) the total CPU count is used instead.

    Raises:
        NotImplementedError: on any other OS.
    """
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(pid))
    if platform.system() == "Darwin":
        # os.cpu_count() returns None when the count cannot be determined;
        # callers stringify this into thread-count env vars, so always
        # report at least one thread.
        return os.cpu_count() or 1
    raise NotImplementedError("Unsupported OS")


39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
@dataclass
class LogicalCPUInfo:
    """One logical CPU entry as parsed from ``lscpu -J -e=CPU,CORE,NODE``.

    A field value of -1 means the attribute was not a parseable integer.
    """

    id: int = -1  # logical CPU id (lscpu "cpu" key)
    physical_core: int = -1  # physical core id (lscpu "core" key)
    numa_node: int = -1  # NUMA node id (lscpu "node" key)

    @classmethod
    def _int(cls, value: str) -> int:
        """Convert ``value`` to ``int``, returning -1 if it cannot be converted."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # TypeError: None/list/dict; ValueError: non-numeric text such as
            # "-" or NaN; OverflowError: infinite floats.
            return -1

    @staticmethod
    def json_decoder(obj_dict: dict):
        """``json.loads`` object_hook that converts lscpu CPU entries.

        A dict whose "cpu", "core" and "node" values are all non-None becomes
        a :class:`LogicalCPUInfo`. Every other dict, including the top-level
        ``{"cpus": [...]}`` wrapper, is returned unchanged.
        """
        cpu_id = obj_dict.get("cpu")
        physical_core = obj_dict.get("core")
        numa_node = obj_dict.get("node")

        if cpu_id is None or physical_core is None or numa_node is None:
            return obj_dict
        return LogicalCPUInfo(
            id=LogicalCPUInfo._int(cpu_id),
            physical_core=LogicalCPUInfo._int(physical_core),
            numa_node=LogicalCPUInfo._int(numa_node),
        )


69
70
class CpuPlatform(Platform):
    """Platform implementation for running vLLM on host CPUs."""

    _enum = PlatformEnum.CPU
    device_name: str = "cpu"
    device_type: str = "cpu"
    dispatch_key: str = "CPU"
    dist_backend: str = "gloo"
    # Comma-separated NUMA node ids; get_allowed_cpu_core_node_list() uses
    # it to narrow the set of reported NUMA nodes.
    device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the model dtypes supported on the host CPU architecture."""
        arch = self.get_cpu_architecture()
        if arch == CpuArchEnum.POWERPC:
            return [torch.bfloat16, torch.float32]
        elif arch == CpuArchEnum.ARM and sys.platform.startswith("darwin"):
            # Apple silicon reports BF16 support through sysctl. A missing
            # key (non-zero exit) or an unavailable sysctl binary is treated
            # as "no BF16" instead of aborting dtype resolution.
            try:
                has_bf16 = (
                    subprocess.check_output(
                        ["sysctl", "-n", "hw.optional.arm.FEAT_BF16"],
                        stderr=subprocess.DEVNULL,
                    ).strip()
                    == b"1"
                )
            except (subprocess.CalledProcessError, OSError):
                has_bf16 = False
            if has_bf16:
                return [torch.bfloat16, torch.float16, torch.float32]
            return [torch.float16, torch.float32]
        elif arch == CpuArchEnum.RISCV:
            # Workaround for Issue #25655: RISC-V scheduler bug with float16
            #
            # Background:
            # - RISC-V currently uses the scalar code path.
            # - A latent bug in the vLLM scheduler provides invalid
            #   physical_block_idx values under certain conditions.
            # - This bug causes segmentation faults when using the float16
            #   dtype on RISC-V.
            # - Testing shows that forcing float32 successfully bypasses
            #   this issue.
            #
            # Technical details:
            # - The bug manifests as out-of-bounds physical_block_idx in
            #   block_tables.
            # - Only occurs on RISC-V hardware (tested on Sophgo SG2044).
            # - Does not reproduce on x86 or other architectures.
            # - Root cause is in Python-level scheduling logic, not C++
            #   kernels.
            #
            # This is a temporary workaround until the scheduler bug is fixed.
            # See: https://github.com/vllm-project/vllm/issues/25655
            return [torch.float32]

        # x86/aarch64 CPU has supported both bf16 and fp16 natively.
        return [torch.bfloat16, torch.float16, torch.float32]

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return "cpu"

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "_Backend",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: Optional[str],
        block_size: int,
        use_v1: bool,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
    ) -> str:
        """Return the attention backend class path; only Torch SDPA on V1.

        Raises:
            NotImplementedError: if MLA or sparse attention is requested.
            ValueError: if ``use_v1`` is False.
        """
        from vllm.attention.backends.registry import _Backend

        if selected_backend and selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        if use_mla:
            raise NotImplementedError("MLA is not supported on CPU.")
        if use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on CPU.")
        # Validate before announcing the backend so the log never reports a
        # backend selection that is about to be rejected.
        if not use_v1:
            raise ValueError("CPU backend only supports V1.")
        logger.info("Using Torch SDPA backend.")
        return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the KV-cache budget in bytes.

        Read from ``VLLM_CPU_KVCACHE_SPACE`` (in GiB), defaulting to 4 GiB
        with a one-time warning when unset.
        """
        import vllm.envs as envs
        from vllm.utils import GiB_bytes

        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
        if kv_cache_space is None:
            kv_cache_space = 4 * GiB_bytes  # type: ignore
            logger.warning_once(
                "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) "
                "for CPU backend is not set, using 4 by default."
            )
        else:
            kv_cache_space *= GiB_bytes

        return kv_cache_space

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cpu.set_device(device)

    @classmethod
    def inference_mode(cls):
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place for CPU execution.

        The steps, in order:

        - Disable cascade attention.
        - Choose or validate the KV-cache block size.
        - Check FP8 KV-cache compatibility and size the KV cache.
        - Fall back to the ``mp`` executor and the CPU worker.
        - Configure compilation.
        - Export environment variables for the CPU executor processes.

        Raises:
            RuntimeError: if a non-16 block size is requested without
                intel_extension_for_pytorch, or if an FP8 KV cache is combined
                with chunked prefill or prefix caching.
        """
        model_config = vllm_config.model_config

        if model_config is not None:
            model_config.disable_cascade_attn = True

        cache_config = vllm_config.cache_config

        ipex_available = find_spec("intel_extension_for_pytorch") is not None

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 128 if ipex_available else 16

        if not ipex_available and cache_config.block_size != 16:
            raise RuntimeError(
                f"--block-size={cache_config.block_size} requires"
                " intel_extension_for_pytorch"
            )

        scheduler_config = vllm_config.scheduler_config
        if (
            scheduler_config.chunked_prefill_enabled
            or cache_config.enable_prefix_caching
        ) and cache_config.cache_dtype != "auto":
            raise RuntimeError(
                "Chunked-prefill and prefix-cache on the CPU "
                "backend is not compatible with FP8 KV cache."
            )

        if cache_config.cache_dtype == "fp8_e4m3":
            cache_config.cache_dtype = "fp8_e5m2"
            logger.warning(
                "CPU backend doesn't support fp8_e4m3 KV cache type, cast to fp8_e5m2."
            )

        if (
            cache_config.cache_dtype != "auto"
            and model_config is not None
            and model_config.dtype == torch.half
        ):
            logger.warning(
                "FP8 KV cache on the CPU backend only does not"
                " support fp16 for now, cast to bf16."
            )
            model_config.dtype = torch.bfloat16

        cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory()

        parallel_config = vllm_config.parallel_config
        if (
            parallel_config.world_size > 1
            and parallel_config.distributed_executor_backend is not None
            and parallel_config.distributed_executor_backend != "mp"
        ):
            logger.warning(
                (
                    "%s is not supported on CPU, fallback to mp "
                    "distributed executor backend."
                ),
                parallel_config.distributed_executor_backend,
            )
            parallel_config.distributed_executor_backend = "mp"
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker"
        # Disable DBO
        if parallel_config.enable_dbo:
            logger.warning("Dual-Batch Overlap is not supported on CPU, disabled.")
            parallel_config.enable_dbo = False

        # Note: workaround for v1 gpu_model_runner
        from vllm.config import CompilationLevel

        vllm_config.compilation_config.cudagraph_capture_sizes = []

        compilation_config = vllm_config.compilation_config
        if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE:
            # Note: vLLM V1 is using PIECEWISE level compilation, which will
            # take time to compile kernels just-in-time with the inductor
            # backend. For CPU CI tests, most of them are executed fast and
            # compilations consume too much time, even with torch compile
            # cache. So use VLLM_CPU_CI_ENV to indicate the CI environment,
            # and just execute model with dynamo + eager mode to save time.
            # VLLM_CPU_CI_ENV is only used as an internal variable.
            if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
                backend = "eager"
            else:
                backend = "inductor"

            compilation_config.level = CompilationLevel.DYNAMO_ONCE
            compilation_config.backend = backend
            compilation_config.inductor_compile_config.update(
                {
                    "dce": True,
                    "size_asserts": False,
                    "nan_asserts": False,
                    "epilogue_fusion": True,
                }
            )
            if compilation_config.use_inductor:
                compilation_config.custom_ops = ["none"]

        if vllm_config.lora_config is not None:
            compilation_config.level = CompilationLevel.NO_COMPILATION

        assert vllm_config.device_config.device_type == "cpu"

        #
        # Environment variables for CPU executor
        #

        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        # Note: to avoid the error 'nthreads cannot be larger than environment
        #  variable "NUMEXPR_MAX_THREADS" (64)'.
        os.environ["NUMEXPR_MAX_THREADS"] = str(get_max_threads())

        # Set default threads num for OpenMP parallel
        os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Intel OpenMP setting
        ld_preload_str = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_preload_str:
            # The time(milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ["KMP_BLOCKTIME"] = "1"
            # Prevents the CPU to run into low performance state
            os.environ["KMP_TPAUSE"] = "0"
            # Provides fine granularity parallelism
            os.environ["KMP_FORKJOIN_BARRIER_PATTERN"] = "dist,dist"
            os.environ["KMP_PLAIN_BARRIER_PATTERN"] = "dist,dist"
            os.environ["KMP_REDUCTION_BARRIER_PATTERN"] = "dist,dist"

        # To hint IPEX uses shared memory based AllReduce
        os.environ["LOCAL_WORLD_SIZE"] = str(
            vllm_config.parallel_config.tensor_parallel_size
        )

        if model_config is not None and model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.chunked_prefill_enabled = False
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

    @classmethod
    def get_allowed_cpu_core_node_list(cls) -> tuple[list[int], list[LogicalCPUInfo]]:
        """Return the usable NUMA nodes and logical CPUs for this process.

        Logical CPUs are read from ``lscpu``. Entries with unparseable
        attributes, and CPUs outside this process's affinity mask, are
        dropped. The NUMA node list is the sorted set of nodes hosting
        allowed CPUs. If ``CPU_VISIBLE_MEMORY_NODES`` is set, the result is
        narrowed to the listed nodes (in the order given) that host allowed
        CPUs.

        Returns:
            ``(numa_node_ids, logical_cpus)``.

        Raises:
            NotImplementedError: if ``os.sched_getaffinity`` is unavailable.
        """
        assert platform.system() == "Linux"

        # Init LogicalCPUInfo from lscpu (argv list; no shell involved).
        lscpu_output = subprocess.check_output(
            ["lscpu", "-J", "-e=CPU,CORE,NODE"], text=True
        )
        logical_cpu_list: list[LogicalCPUInfo] = json.loads(
            lscpu_output, object_hook=LogicalCPUInfo.json_decoder
        )["cpus"]

        # Filter CPUs with invalid attributes. Entries missing a field are
        # left as plain dicts by LogicalCPUInfo.json_decoder, so drop those
        # too rather than failing on attribute access.
        logical_cpu_list = [
            x
            for x in logical_cpu_list
            if isinstance(x, LogicalCPUInfo)
            and -1 not in (x.id, x.physical_core, x.numa_node)
        ]

        # Filter allowed CPUs
        if not hasattr(os, "sched_getaffinity"):
            raise NotImplementedError("Unsupported OS")
        allowed_cpu_id_list = os.sched_getaffinity(0)
        logical_cpu_list = [x for x in logical_cpu_list if x.id in allowed_cpu_id_list]

        # Get allowed NUMA nodes
        allowed_numa_nodes = {x.numa_node for x in logical_cpu_list}
        allowed_numa_nodes_list = sorted(allowed_numa_nodes)

        env_key = CpuPlatform.device_control_env_var
        visible_nodes_str = os.environ.get(env_key, "")
        if visible_nodes_str:
            visible_nodes = [
                int(s) for s in visible_nodes_str.split(",") if s.strip()
            ]
            # Compare against NUMA node ids (not CPU ids): keep only the
            # requested nodes that host at least one allowed CPU.
            allowed_numa_nodes_list = [
                x for x in visible_nodes if x in allowed_numa_nodes
            ]

        return allowed_numa_nodes_list, logical_cpu_list

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        logger.warning("Pin memory is not supported on CPU.")
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa

    @classmethod
    def supports_structured_output(cls) -> bool:
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True