# cpu.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import glob
5
import os
6
import platform
7
import subprocess
8
import sys
9
from dataclasses import dataclass
10
from typing import TYPE_CHECKING
11

12
import psutil
13
14
import torch

15
from vllm import envs
16
from vllm.logger import init_logger
17
from vllm.utils.ompmultiprocessing import OMPProcessManager
18
from vllm.utils.torch_utils import is_quantized_kv_cache
19
from vllm.v1.attention.backends.registry import AttentionBackendEnum
20

21
from .interface import CpuArchEnum, Platform, PlatformEnum
22
23

logger = init_logger(__name__)
24

25
26
if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.v1.attention.selector import AttentionSelectorConfig
else:
    # Runtime placeholder so the unquoted ``VllmConfig`` annotation on
    # ``CpuPlatform.check_and_update_config`` evaluates without importing
    # vllm.config at module load (presumably to avoid an import cycle —
    # confirm).
    VllmConfig = None

31

32
def get_max_threads(pid: int = 0) -> int:
    """Return the number of logical CPUs usable by process ``pid``.

    Where ``os.sched_getaffinity`` exists (e.g. Linux) this is the size of the
    process's CPU affinity mask. On macOS, which lacks that call, it falls back
    to ``os.cpu_count()`` and ``pid`` is ignored.

    Raises:
        NotImplementedError: on any other platform.
    """
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(pid))
    if platform.system() == "Darwin":
        # os.cpu_count() may return None when the count cannot be determined;
        # callers stringify this value (e.g. into NUMEXPR_MAX_THREADS), so
        # never hand back None.
        return os.cpu_count() or 1
    raise NotImplementedError("Unsupported OS")


41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
@dataclass
class LogicalCPUInfo:
    """One logical CPU record decoded from a JSON CPU listing.

    Each field is -1 when the source value is missing or not an integer.
    """

    id: int = -1
    physical_core: int = -1
    numa_node: int = -1

    @classmethod
    def _int(cls, value: object) -> int:
        """Convert ``value`` to ``int``, returning -1 if it cannot be parsed."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # e.g. None, "-" or any other non-numeric placeholder
            return -1

    @staticmethod
    def json_decoder(obj_dict: dict) -> "LogicalCPUInfo | dict":
        """Turn a dict with "cpu", "core" and "node" keys into a LogicalCPUInfo.

        Any dict lacking one of those keys is returned unchanged. That makes
        this usable as a ``json.loads`` ``object_hook`` (presumably for
        ``lscpu``-style JSON output; confirm against the caller).
        """
        cpu_id = obj_dict.get("cpu")
        physical_core = obj_dict.get("core")
        numa_node = obj_dict.get("node")

        if cpu_id is None or physical_core is None or numa_node is None:
            # Not (or not entirely) a CPU record: leave it for the caller.
            return obj_dict
        return LogicalCPUInfo(
            id=LogicalCPUInfo._int(cpu_id),
            physical_core=LogicalCPUInfo._int(physical_core),
            numa_node=LogicalCPUInfo._int(numa_node),
        )


71
72
class CpuPlatform(Platform):
    _enum = PlatformEnum.CPU
    device_name: str = "cpu"
    device_type: str = "cpu"
    dispatch_key: str = "CPU"
    dist_backend: str = "gloo"
    device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"

    # Created lazily by get_omp_manager() and rebuilt by reserve_cpus().
    omp_process_manager: "OMPProcessManager | None" = None

    # Simultaneous Multithreading (SMT) level for OpenMP:
    # 4 on PowerPC, 1 on non-PowerPC architectures
    smt: int = 1

    # Cached CPU affinity mask, see get_global_cpu_mask() / reserve_cpus().
    global_cpu_mask: "set[int] | None" = None

    # Minimum number of simulated NUMA nodes (CI only; 0 disables it).
    # This is read once, when the class body executes at import time, so it
    # only affects processes that import this module after
    # check_and_update_config() has exported _SIM_MULTI_NUMA (presumably the
    # spawned workers — confirm).
    simulate_numa: int = int(os.environ.get("_SIM_MULTI_NUMA", 0))
    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Floating-point dtypes this CPU can run models in."""
        arch = self.get_cpu_architecture()
        if arch == CpuArchEnum.POWERPC:
            return [torch.bfloat16, torch.float32]
        if arch == CpuArchEnum.ARM and sys.platform.startswith("darwin"):
            # On macOS/ARM, BF16 availability is queried from sysctl.
            if self._darwin_has_arm_bf16():
                return [torch.bfloat16, torch.float16, torch.float32]
            return [torch.float16, torch.float32]
        # x86/aarch64 (and RISC-V) CPUs support both bf16 and fp16 natively.
        return [torch.bfloat16, torch.float16, torch.float32]

    @staticmethod
    def _darwin_has_arm_bf16() -> bool:
        """Return True if macOS reports ``hw.optional.arm.FEAT_BF16 == 1``."""
        try:
            # Argument list without a shell: no shell parsing is needed here.
            out = subprocess.check_output(
                ["sysctl", "-n", "hw.optional.arm.FEAT_BF16"],
                stderr=subprocess.DEVNULL,
            )
        except (OSError, subprocess.CalledProcessError):
            # The key may not exist (sysctl exits non-zero), or sysctl itself
            # may be unavailable: treat both as "no BF16" instead of crashing.
            return False
        return out.strip() == b"1"
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name; every CPU device is reported as "cpu"."""
        return "cpu"
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> str:
        """Return the import path of the attention backend used on CPU.

        CPU_ATTN is always returned. Any other explicitly selected backend is
        logged and ignored.

        Raises:
            NotImplementedError: if MLA or sparse attention is requested.
        """
        if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        if attn_selector_config.use_mla:
            raise NotImplementedError("MLA is not supported on CPU.")
        if attn_selector_config.use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on CPU.")
        return AttentionBackendEnum.CPU_ATTN.get_path()
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the KV-cache memory budget in bytes.

        If ``VLLM_CPU_KVCACHE_SPACE`` (in GiB) is set, it is used directly.
        Otherwise the budget defaults to half of the total system memory
        divided by the number of NUMA nodes listed under
        /sys/devices/system/node (1 if none are found), and a warning is
        logged once.
        """
        from vllm.utils.mem_constants import GiB_bytes
        from vllm.utils.mem_utils import format_gib

        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
        if kv_cache_space is not None:
            return kv_cache_space * GiB_bytes

        node_dir = "/sys/devices/system/node"
        num_numa_nodes = 1
        if os.path.exists(node_dir):
            # Count only real node entries ("node0", "node1", ...), using the
            # same filter as discover_numa_topology().
            num_numa_nodes = (
                sum(
                    1
                    for entry in os.listdir(node_dir)
                    if entry.startswith("node") and entry[4:].isdigit()
                )
                or 1
            )
        # Total (not free) memory, split evenly across NUMA nodes.
        per_node_memory = psutil.virtual_memory().total // num_numa_nodes
        DEFAULT_CPU_MEM_UTILIZATION = 0.5
        kv_cache_space = int(per_node_memory * DEFAULT_CPU_MEM_UTILIZATION)
        logger.warning_once(
            "VLLM_CPU_KVCACHE_SPACE not set. Using %s GiB for KV cache.",
            format_gib(kv_cache_space),
        )
        return kv_cache_space
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cpu.set_device(device)
    @classmethod
    def manual_seed_all(cls, seed: int) -> None:
        """Intentionally a no-op on CPU; ``seed`` is ignored."""
    @classmethod
    def inference_mode(cls):
        """Return the context manager used to run inference on CPU.

        This is ``torch.no_grad()``, not ``torch.inference_mode()``.
        NOTE(review): the reason for that choice is not visible here.
        """
        return torch.no_grad()
    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adapt ``vllm_config`` in place for the CPU backend.

        Adjusts the model, cache, scheduler, parallel, compilation and
        profiler configs. For example, it disables cascade attention, async
        scheduling, DBO and CUDA-graph capture, and forces the "mp" executor
        and the CPU worker. It then exports the environment variables the CPU
        workers rely on (see ``_export_cpu_env_vars``).

        Raises:
            RuntimeError: if chunked prefill or prefix caching is enabled
                together with a quantized KV-cache dtype.
        """
        model_config = vllm_config.model_config
        if model_config is not None:
            model_config.disable_cascade_attn = True

        cache_config = vllm_config.cache_config
        if not cache_config.user_specified_block_size:
            cache_config.block_size = 128
        if cache_config.block_size % 32 != 0:
            logger.warning(
                "CPU backend prefers block_size is multiples of 32, "
                "otherwise the performance is not optimized."
            )

        scheduler_config = vllm_config.scheduler_config
        # async scheduling is not required on CPU
        scheduler_config.async_scheduling = False
        if (
            scheduler_config.enable_chunked_prefill
            or cache_config.enable_prefix_caching
        ) and is_quantized_kv_cache(cache_config.cache_dtype):
            raise RuntimeError(
                "Chunked-prefill and prefix-cache on the CPU "
                "backend is not compatible with FP8 KV cache."
            )
        if is_quantized_kv_cache(cache_config.cache_dtype):
            logger.warning(
                "CPU backend doesn't support KV cache quantization fallback to auto."
            )
            cache_config.cache_dtype = "auto"

        cache_config.cpu_kvcache_space_bytes = cls.get_device_total_memory()

        parallel_config = vllm_config.parallel_config
        # OMP requires the MP executor to function correctly, UniProc is not
        # supported as it is not possible to set the OMP environment correctly
        if parallel_config.distributed_executor_backend == "uni":
            parallel_config.distributed_executor_backend = "mp"
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker"
        # Disable DBO
        if parallel_config.enable_dbo:
            logger.warning("Dual-Batch Overlap is not supported on CPU, disabled.")
            parallel_config.enable_dbo = False

        cpu_architecture = cls.get_cpu_architecture()
        cls._configure_compilation(vllm_config, cpu_architecture)

        vllm_config.profiler_config.torch_profiler_dump_cuda_time_total = False

        assert vllm_config.device_config.device_type == "cpu"

        cls._export_cpu_env_vars(vllm_config, cpu_architecture)

        if model_config is not None and model_config.use_mla:
            # NOTE(review): the message says prefix caching is disabled too,
            # but only chunked prefill is turned off below. Confirm intent.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.max_num_batched_tokens = max(
                model_config.max_model_len,
                scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

        # CI specific "quick" NUMA simulation - split all available CPUs
        # into a fake NUMA topology
        if os.environ.get("VLLM_CPU_SIM_MULTI_NUMA") is not None:
            os.environ["_SIM_MULTI_NUMA"] = str(
                parallel_config.world_size * parallel_config._api_process_count
            )

    @classmethod
    def _configure_compilation(
        cls, vllm_config: VllmConfig, cpu_architecture: CpuArchEnum
    ) -> None:
        """Tune the compilation config for CPU (no CUDA graphs, trace once)."""
        from vllm.config import CompilationMode

        compilation_config = vllm_config.compilation_config
        # Note: workaround for v1 gpu_model_runner
        compilation_config.cudagraph_capture_sizes = []

        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
            # Note: vLLM V1 is using PIECEWISE level compilation, which will
            # take time to compile kernels just-in-time with the inductor
            # backend. For CPU CI tests, most of them are executed fast and
            # compilations consume too much time, even with torch compile
            # cache. So use VLLM_CPU_CI_ENV to indicate the CI environment,
            # and just execute model with dynamo + eager mode to save time.
            # VLLM_CPU_CI_ENV is only used as an internal variable.
            if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
                backend = "eager"
            else:
                backend = "inductor"

            compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
            compilation_config.backend = backend
            compilation_config.inductor_compile_config.update(
                {
                    "dce": True,
                    "size_asserts": False,
                    "nan_asserts": False,
                    "epilogue_fusion": True,
                    "cpp.dynamic_threads": True,
                }
            )
            compilation_config.ir_enable_torch_wrap = False

        if vllm_config.lora_config is not None:
            compilation_config.mode = CompilationMode.NONE

        if (
            cpu_architecture == CpuArchEnum.ARM
            and "+gelu" not in compilation_config.custom_ops
            and "-gelu" not in compilation_config.custom_ops
        ):
            compilation_config.custom_ops.append("+gelu")

    @classmethod
    def _export_cpu_env_vars(
        cls, vllm_config: VllmConfig, cpu_architecture: CpuArchEnum
    ) -> None:
        """Export the process environment variables used by the CPU executor."""
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        # Note: to avoid the error 'nthreads cannot be larger than environment
        # variable "NUMEXPR_MAX_THREADS" (64)'.
        os.environ["NUMEXPR_MAX_THREADS"] = str(get_max_threads())

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Disable multi-stream for shared experts as there is no Stream on CPU
        os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"

        # Keep inductor from generating num_thread() calls, which would break
        # the thread binding
        os.environ["TORCHINDUCTOR_CPP_DYNAMIC_THREADS"] = "1"

        ld_preload_str = os.getenv("LD_PRELOAD", "")

        # Intel and CLANG OpenMP setting
        if "libiomp5.so" in ld_preload_str or "libomp5" in ld_preload_str:
            # The time (milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ["KMP_BLOCKTIME"] = "1"
            # Prevents the CPU from entering a low-performance state
            os.environ["KMP_TPAUSE"] = "0"
            # Provides fine granularity parallelism
            os.environ["KMP_FORKJOIN_BARRIER_PATTERN"] = "dist,dist"
            os.environ["KMP_PLAIN_BARRIER_PATTERN"] = "dist,dist"
            os.environ["KMP_REDUCTION_BARRIER_PATTERN"] = "dist,dist"

        cls._preload_bundled_tcmalloc(ld_preload_str, cpu_architecture)

        os.environ["LOCAL_WORLD_SIZE"] = str(
            vllm_config.parallel_config.tensor_parallel_size
        )

    @staticmethod
    def _preload_bundled_tcmalloc(
        ld_preload_str: str, cpu_architecture: CpuArchEnum
    ) -> None:
        """Prepend the tcmalloc bundled under vllm/libs to LD_PRELOAD.

        This only applies on Linux x86/ARM when no libtcmalloc is preloaded
        yet, to reduce memory allocation overhead. It does nothing if no
        bundled library is found.
        """
        if (
            platform.system() != "Linux"
            or cpu_architecture not in (CpuArchEnum.ARM, CpuArchEnum.X86)
            or "libtcmalloc" in ld_preload_str
        ):
            return

        # This module lives one package below the vllm package root, so two
        # dirname() calls give the vllm package directory.
        vllm_pkg = os.path.dirname(os.path.dirname(__file__))
        for pattern in ("libtcmalloc_minimal*.so*", "libtcmalloc.so*"):
            # glob order is arbitrary; sort so the choice is deterministic.
            candidates = sorted(glob.glob(os.path.join(vllm_pkg, "libs", pattern)))
            if candidates:
                tcmalloc_so = candidates[0]
                break
        else:
            return

        os.environ["LD_PRELOAD"] = (
            f"{tcmalloc_so}:{ld_preload_str}" if ld_preload_str else tcmalloc_so
        )
    @classmethod
    def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
        """No-op: the CPU block size is still chosen in check_and_update_config."""
        # TODO: CPU still sets block_size in check_and_update_config.
        # Move that logic here so block_size is chosen by the backend.
    @classmethod
    def get_omp_manager(cls) -> OMPProcessManager:
        """Return the OpenMP process manager, creating it on first use.

        On first use, this sets the SMT level to 4 on PowerPC and builds the
        manager from the global CPU mask. When CI NUMA simulation is enabled
        (``simulate_numa > 0``), it then splits the OMP places until there
        are at least ``simulate_numa`` of them.

        Raises:
            ValueError: if the places cannot be split enough to reach the
                simulated NUMA node count.
        """
        if cls.omp_process_manager is None:
            if cls.get_cpu_architecture() == CpuArchEnum.POWERPC:
                cls.smt = 4
            cls.omp_process_manager = OMPProcessManager(
                affinity=cls.get_global_cpu_mask(), smt=cls.smt
            )
            # we need to fix up the topology returned by the OMP Manager for
            # simulated NUMA environments in CI
            if cls.simulate_numa > 0:
                logger.info(
                    "Adjusting numa topology to resemble at least %d nodes",
                    int(cls.simulate_numa),
                )
                om = cls.omp_process_manager
                while len(om.omp_places) < cls.simulate_numa:
                    om.omp_places = cls._split_omp_places(om.omp_places)

        return cls.omp_process_manager

    @staticmethod
    def _split_omp_places(places: list[dict]) -> list[dict]:
        """Halve every OMP place that has more than one CPU (one sim step).

        Places with a single CPU are kept as they are. Previously they were
        dropped, which silently removed those CPUs from the topology.

        Raises:
            ValueError: if no place could be split.
        """
        new_places: list[dict] = []
        touched = False
        for place in places:
            cpus = sorted(place["mask"])
            if len(cpus) <= 1:
                new_places.append(place)
                continue
            touched = True
            half = len(cpus) // 2
            new_places.append({"mask": set(cpus[:half]), "available": True})
            new_places.append({"mask": set(cpus[half:]), "available": True})
        if not touched:
            raise ValueError(
                "Cannot split the existing NUMA topology to match "
                "simulation requirements"
            )
        return new_places
    @classmethod
    def get_global_cpu_mask(cls) -> set[int]:
        """Return the set of CPU ids available to this process (cached).

        Where ``os.sched_getaffinity`` exists, this is the process affinity
        mask. Otherwise (e.g. macOS) CPUs ``0 .. os.cpu_count() - 1`` are
        assumed. The cached set object is returned as-is, so callers must not
        mutate it.
        """
        if cls.global_cpu_mask is None:
            if hasattr(os, "sched_getaffinity"):
                cls.global_cpu_mask = os.sched_getaffinity(0)
            else:
                # macOS does not support sched_getaffinity
                cls.global_cpu_mask = set(range(os.cpu_count() or 1))
        return cls.global_cpu_mask
    @classmethod
    def reserve_cpus(cls, reserve: set[int]) -> bool:
        """Remove ``reserve`` from the global CPU mask and rebuild OMP places.

        There is no "release" counterpart. The request is refused if any
        existing OMP place is already marked unavailable.

        Returns:
            True if the CPUs were reserved, False if the request was refused.
        """
        manager = cls.omp_process_manager
        if manager is not None and any(
            not place["available"] for place in manager.omp_places
        ):
            return False
        cls.global_cpu_mask = cls.get_global_cpu_mask() - reserve
        # Rebuild via get_omp_manager() so the new manager gets the same setup
        # as a first-time one (PowerPC SMT level, CI NUMA simulation). A bare
        # OMPProcessManager(...) here would skip both.
        cls.omp_process_manager = None
        cls.get_omp_manager()
        return True
    @classmethod
    def discover_numa_topology(cls) -> list[list[int]]:
        """
        Discover the physical cores of each NUMA node.

        Returns one list per NUMA node that has CPUs, ordered by node id.
        Each list holds that node's physical cores in ascending order, where
        a core is represented by the lowest CPU id among its thread siblings.
        The result is used to build the core groups for nixl start_kv_load().
        Returns [] when the sysfs topology directories are missing.
        """
        SYS_NODE = "/sys/devices/system/node"
        SYS_CPU = "/sys/devices/system/cpu"

        if not (os.path.exists(SYS_NODE) and os.path.exists(SYS_CPU)):
            return []

        # os.listdir order is arbitrary; sort nodes numerically.
        node_names = sorted(
            (
                name
                for name in os.listdir(SYS_NODE)
                if name.startswith("node") and name[4:].isdigit()
            ),
            key=lambda name: int(name[4:]),
        )

        core_rsv_for_kv = []
        for node in node_names:
            node_path = f"{SYS_NODE}/{node}"
            phys_cores: set[int] = set()
            for entry in os.listdir(node_path):
                if not entry.startswith("cpu") or not entry[3:].isdigit():
                    continue
                cpu_id = int(entry[3:])
                # thread_siblings based on cpu_id
                siblings = cls._read_thread_siblings(
                    f"{SYS_CPU}/cpu{cpu_id}/topology/thread_siblings_list", cpu_id
                )
                phys_cores.add(min(siblings))
            if phys_cores:
                core_rsv_for_kv.append(sorted(phys_cores))

        return core_rsv_for_kv

    @staticmethod
    def _read_thread_siblings(path: str, cpu_id: int) -> list[int]:
        """Parse a sysfs CPU list such as "0-3,8"; fall back to [cpu_id]."""
        if not os.path.exists(path):
            return [cpu_id]
        try:
            with open(path) as f:
                text = f.read()
            cpus: list[int] = []
            for part in text.strip().split(","):
                if "-" in part:
                    lo, hi = map(int, part.split("-"))
                    cpus.extend(range(lo, hi + 1))
                else:
                    cpus.append(int(part))
        except (OSError, ValueError):
            return [cpu_id]
        return cpus or [cpu_id]
    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Pinned host memory is never used on the CPU platform."""
        return False
    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the CPU LoRA (Punica) wrapper class."""
        return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa
    @classmethod
    def supports_structured_output(cls) -> bool:
        """Structured output is supported on CPU."""
        return True
    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Report that attention is an opaque op on this platform (always True)."""
        return True
    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Hybrid KV cache is supported on CPU."""
        return True
    @classmethod
    def import_kernels(cls) -> None:
        """Import the compiled vLLM CPU kernel extension for this machine.

        On x86 the variant is picked by ISA: ``vllm._C`` with AVX512-BF16,
        ``vllm._C_AVX512`` with plain AVX512, and ``vllm._C_AVX2`` otherwise.
        Other architectures import ``vllm._C``. Import failures are logged,
        not raised.
        """
        if cls.get_cpu_architecture() != CpuArchEnum.X86:
            cls._import_kernel_module("vllm._C")
        elif not torch.cpu._is_avx512_supported():
            cls._import_kernel_module("vllm._C_AVX2", ignore_export_error=True)
        elif torch.cpu._is_avx512_bf16_supported():
            cls._import_kernel_module("vllm._C")
        else:
            cls._import_kernel_module("vllm._C_AVX512", ignore_export_error=True)

    @staticmethod
    def _import_kernel_module(name: str, ignore_export_error: bool = False) -> None:
        """Import ``name`` for its side effects; log failures instead of raising.

        Note: the ISA-specific libraries are named _C_AVX2/_C_AVX512, but the
        module they define is _C. Importing them therefore raises "dynamic
        module does not define module export function", although the library
        itself is loaded. With ``ignore_export_error`` that specific error is
        not logged.
        """
        import importlib

        try:
            importlib.import_module(name)
        except ImportError as e:
            # str(e) instead of e.msg: msg may be None for some ImportErrors.
            if (
                ignore_export_error
                and "dynamic module does not define module export function" in str(e)
            ):
                return
            logger.warning("Failed to import from %s: %r", name, e)
    @classmethod
    def pack_kv_cache(
        cls,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_ids: list[int],
        indices: torch.Tensor,
    ) -> None:
        """
        Scatter ``key``/``value`` blocks into the CPU attention KV cache.

        Row ``b`` of ``indices`` (one row per entry of ``block_ids``) is the
        destination cache block for source block ``b``, and every slot of that
        block is written. For CPU_ATTN the cache is laid out as
        [N, num_kv_heads, block_size, head_size]. ``key``/``value`` presumably
        share that per-block layout (TODO confirm with callers).
        """
        # Import lazily: cpu_attn pulls in _custom_ops, which needs a fully
        # initialized vllm.platforms (avoid circular import while CpuPlatform loads).
        from vllm._custom_ops import cpu_attn_reshape_and_cache
        from vllm.v1.attention.backends.cpu_attn import _get_attn_isa

        src_dtype = key.dtype
        _, _, tokens_per_block, head_dim = key_cache.shape

        # Swap dims 1 and 2, then merge the first two: one row per token slot,
        # ordered block by block.
        flat_key = key.permute(0, 2, 1, 3).flatten(0, 1)
        flat_value = value.permute(0, 2, 1, 3).flatten(0, 1)

        isa = _get_attn_isa(src_dtype, tokens_per_block, head_dim)

        # slot_mapping[b * tokens_per_block + o] = indices[b] * tokens_per_block + o
        offsets = torch.arange(tokens_per_block, device="cpu", dtype=torch.long)
        block_count = len(block_ids)
        block_starts = indices.reshape(block_count, 1) * tokens_per_block
        slot_mapping = (offsets.reshape(1, tokens_per_block) + block_starts).flatten()

        cpu_attn_reshape_and_cache(
            flat_key, flat_value, key_cache, value_cache, slot_mapping, isa
        )