cpu.py 19.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import glob
5
import os
6
import platform
7
import subprocess
8
import sys
9
from dataclasses import dataclass
10
from typing import TYPE_CHECKING
11

12
import psutil
13
14
import torch

15
from vllm import envs
16
from vllm.logger import init_logger
17
from vllm.utils.ompmultiprocessing import OMPProcessManager
18
from vllm.utils.torch_utils import is_quantized_kv_cache
19
from vllm.v1.attention.backends.registry import AttentionBackendEnum
20

21
from .interface import CpuArchEnum, Platform, PlatformEnum
22
23

logger = init_logger(__name__)


if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.v1.attention.selector import AttentionSelectorConfig
else:
    # These names are imported for type checkers only. VllmConfig is also used
    # in an unquoted method annotation in this module, so it needs a runtime
    # binding; AttentionSelectorConfig only appears in quoted annotations.
    VllmConfig = None


def get_max_threads(pid: int = 0) -> int:
    """Return the number of CPUs available to a process.

    Args:
        pid: Process id to query; 0 means the calling process. Only used
            where ``os.sched_getaffinity`` exists (e.g. Linux).

    Returns:
        The size of the process' CPU affinity set. On macOS, where there is
        no affinity API, the machine's CPU count is returned instead, falling
        back to 1 if it cannot be determined.

    Raises:
        NotImplementedError: on other platforms without ``sched_getaffinity``.
    """
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(pid))
    if platform.system() == "Darwin":
        # os.cpu_count() may return None; never report a non-positive count.
        return os.cpu_count() or 1
    raise NotImplementedError("Unsupported OS")


@dataclass
class LogicalCPUInfo:
    id: int = -1
    physical_core: int = -1
    numa_node: int = -1

    @classmethod
    def _int(cls, value: str) -> int:
        try:
            int_value = int(value)
        except Exception:
            int_value = -1
        return int_value

    @staticmethod
    def json_decoder(obj_dict: dict):
        id = obj_dict.get("cpu")
        physical_core = obj_dict.get("core")
        numa_node = obj_dict.get("node")

        if not (id is None or physical_core is None or numa_node is None):
            return LogicalCPUInfo(
                id=LogicalCPUInfo._int(id),
                physical_core=LogicalCPUInfo._int(physical_core),
65
66
                numa_node=LogicalCPUInfo._int(numa_node),
            )
67
68
69
70
        else:
            return obj_dict


71
72
class CpuPlatform(Platform):
    """vLLM ``Platform`` implementation for running on host CPUs.

    Besides the standard platform hooks, this class keeps process-wide
    OpenMP resource bookkeeping (``get_omp_manager`` / ``reserve_cpus``)
    over the CPUs this process is allowed to run on.
    """

    _enum = PlatformEnum.CPU
    device_name: str = "cpu"
    device_type: str = "cpu"
    dispatch_key: str = "CPU"
    dist_backend: str = "gloo"
    device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"

    # Lazily created by get_omp_manager() and rebuilt by reserve_cpus().
    omp_process_manager = None
    smt = 1  # SMT level for OMP - 4 threads on PowerPC, 1 on others
    # Cached CPU affinity mask; reserve_cpus() removes CPUs from it.
    global_cpu_mask = None
    # Minimum number of simulated NUMA nodes for CI (0 disables simulation).
    # Read when the class is defined; check_and_update_config() exports
    # _SIM_MULTI_NUMA so processes that import this module later see it.
    simulate_numa = int(os.environ.get("_SIM_MULTI_NUMA", 0))

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the torch dtypes this CPU can run models in."""
        arch = self.get_cpu_architecture()
        if arch == CpuArchEnum.POWERPC:
            return [torch.bfloat16, torch.float32]
        if arch == CpuArchEnum.ARM and sys.platform.startswith("darwin"):
            # Apple Silicon only has bf16 when the CPU reports FEAT_BF16.
            if self._darwin_has_bf16():
                return [torch.bfloat16, torch.float16, torch.float32]
            return [torch.float16, torch.float32]
        # x86/aarch64 CPU has supported both bf16 and fp16 natively
        # (RISC-V reports the same set).
        return [torch.bfloat16, torch.float16, torch.float32]

    @staticmethod
    def _darwin_has_bf16() -> bool:
        """Return True if macOS reports the Arm FEAT_BF16 extension.

        A missing ``sysctl`` binary or an unknown sysctl key (older macOS)
        is treated as "not supported" instead of raising.
        """
        try:
            out = subprocess.check_output(
                ["sysctl", "-n", "hw.optional.arm.FEAT_BF16"],
                stderr=subprocess.DEVNULL,
            )
        except (OSError, subprocess.CalledProcessError):
            return False
        return out.strip() == b"1"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return "cpu"

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> str:
        """Return the import path of the attention backend (always CPU_ATTN).

        A different ``selected_backend`` is only logged and then ignored.

        Raises:
            NotImplementedError: if MLA or sparse attention is requested.
        """
        if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        if attn_selector_config.use_mla:
            raise NotImplementedError("MLA is not supported on CPU.")
        if attn_selector_config.use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on CPU.")
        return AttentionBackendEnum.CPU_ATTN.get_path()

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the KV-cache memory budget in bytes.

        ``VLLM_CPU_KVCACHE_SPACE`` is interpreted as GiB when set. Otherwise
        the budget defaults to half of total system memory divided by the
        number of NUMA nodes found under ``/sys/devices/system/node`` (at
        least 1).
        """
        from vllm.utils.mem_constants import GiB_bytes
        from vllm.utils.mem_utils import format_gib

        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
        if kv_cache_space is not None:
            return int(kv_cache_space * GiB_bytes)

        node_dir = "/sys/devices/system/node"
        num_numa_nodes = 1
        if os.path.exists(node_dir):
            # Count only "node<N>" entries; the directory also holds files
            # such as "online" or "possible".
            num_numa_nodes = (
                sum(
                    1
                    for entry in os.listdir(node_dir)
                    if entry.startswith("node") and entry[4:].isdigit()
                )
                or 1
            )

        default_cpu_mem_utilization = 0.5
        per_node_memory = psutil.virtual_memory().total // num_numa_nodes
        kv_cache_space = int(per_node_memory * default_cpu_mem_utilization)
        logger.warning_once(
            "VLLM_CPU_KVCACHE_SPACE not set. Using %s GiB for KV cache.",
            format_gib(kv_cache_space),
        )
        return kv_cache_space

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cpu.set_device(device)

    @classmethod
    def inference_mode(cls):
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adapt ``vllm_config`` in place for the CPU backend.

        Also exports the environment variables that CPU execution relies on
        (OpenMP/inductor tuning, tcmalloc preload, worker start method).

        Raises:
            RuntimeError: if chunked prefill or prefix caching is combined
                with a quantized KV cache.
        """
        model_config = vllm_config.model_config

        if model_config is not None:
            model_config.disable_cascade_attn = True

        cache_config = vllm_config.cache_config

        if not cache_config.user_specified_block_size:
            cache_config.block_size = 128

        if cache_config.block_size % 32 != 0:
            logger.warning(
                "CPU backend prefers block_size is multiples of 32, "
                "otherwise the performance is not optimized."
            )

        scheduler_config = vllm_config.scheduler_config

        # async scheduling is not required on CPU
        scheduler_config.async_scheduling = False
        if (
            scheduler_config.enable_chunked_prefill
            or cache_config.enable_prefix_caching
        ) and is_quantized_kv_cache(cache_config.cache_dtype):
            raise RuntimeError(
                "Chunked-prefill and prefix-cache on the CPU "
                "backend is not compatible with FP8 KV cache."
            )

        if is_quantized_kv_cache(cache_config.cache_dtype):
            logger.warning(
                "CPU backend doesn't support KV cache quantization fallback to auto."
            )
            cache_config.cache_dtype = "auto"

        cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory()

        parallel_config = vllm_config.parallel_config
        # OMP requires the MP executor to function correctly, UniProc is not
        # supported as it is not possible to set the OMP environment correctly
        if parallel_config.distributed_executor_backend == "uni":
            parallel_config.distributed_executor_backend = "mp"
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker"
        # Disable DBO
        if parallel_config.enable_dbo:
            logger.warning("Dual-Batch Overlap is not supported on CPU, disabled.")
            parallel_config.enable_dbo = False

        # Note: workaround for v1 gpu_model_runner
        from vllm.config import CompilationMode

        compilation_config = vllm_config.compilation_config
        compilation_config.cudagraph_capture_sizes = []

        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
            # Note: vLLM V1 is using PIECEWISE level compilation, which will
            # take time to compile kernels just-in-time with the inductor
            # backend. For CPU CI tests, most of them are executed fast and
            # compilations consume too much time, even with torch compile
            # cache. So use VLLM_CPU_CI_ENV to indicate the CI environment,
            # and just execute model with dynamo + eager mode to save time.
            # VLLM_CPU_CI_ENV is only used as an internal variable.
            if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
                backend = "eager"
            else:
                backend = "inductor"

            compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
            compilation_config.backend = backend
            compilation_config.inductor_compile_config.update(
                {
                    "dce": True,
                    "size_asserts": False,
                    "nan_asserts": False,
                    "epilogue_fusion": True,
                    "cpp.dynamic_threads": True,
                }
            )

        if vllm_config.lora_config is not None:
            compilation_config.mode = CompilationMode.NONE

        vllm_config.profiler_config.torch_profiler_dump_cuda_time_total = False

        assert vllm_config.device_config.device_type == "cpu"

        #
        # Environment variables for CPU executor
        #

        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        # Note: to avoid the error 'nthreads cannot be larger than environment
        # variable "NUMEXPR_MAX_THREADS" (64)'.
        os.environ["NUMEXPR_MAX_THREADS"] = str(get_max_threads())

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Disable multi-stream for shared experts as no Stream on CPU
        os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"

        # Avoid inductor generates num_thread() and breaks the thread binding
        os.environ["TORCHINDUCTOR_CPP_DYNAMIC_THREADS"] = "1"

        ld_preload_str = os.getenv("LD_PRELOAD", "")

        # Intel and CLANG OpenMP setting
        if "libiomp5.so" in ld_preload_str or "libomp5" in ld_preload_str:
            # The time(milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ["KMP_BLOCKTIME"] = "1"
            # Prevents the CPU to run into low performance state
            os.environ["KMP_TPAUSE"] = "0"
            # Provides fine granularity parallelism
            os.environ["KMP_FORKJOIN_BARRIER_PATTERN"] = "dist,dist"
            os.environ["KMP_PLAIN_BARRIER_PATTERN"] = "dist,dist"
            os.environ["KMP_REDUCTION_BARRIER_PATTERN"] = "dist,dist"

        cpu_architecture = Platform.get_cpu_architecture()

        # LD_PRELOAD libtcmalloc, bundled under vllm/libs to reduce
        # memory allocation overhead
        if (
            platform.system() == "Linux"
            and cpu_architecture in (CpuArchEnum.ARM, CpuArchEnum.X86)
            and "libtcmalloc" not in ld_preload_str
        ):
            vllm_pkg = os.path.dirname(os.path.dirname(__file__))
            tcmalloc_so = None
            for pattern in ("libtcmalloc_minimal*.so*", "libtcmalloc.so*"):
                # Sort so the chosen library does not depend on directory order.
                tcmalloc_so_candidates = sorted(
                    glob.glob(os.path.join(vllm_pkg, "libs", pattern))
                )
                if tcmalloc_so_candidates:
                    tcmalloc_so = tcmalloc_so_candidates[0]
                    break

            if tcmalloc_so is not None:
                if ld_preload_str:
                    ld_preload_str = f"{tcmalloc_so}:{ld_preload_str}"
                else:
                    ld_preload_str = tcmalloc_so
                os.environ["LD_PRELOAD"] = ld_preload_str

        os.environ["LOCAL_WORLD_SIZE"] = str(parallel_config.tensor_parallel_size)

        if model_config is not None and model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.max_num_batched_tokens = max(
                model_config.max_model_len,
                scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

        # CI specific "quick" NUMA simulation - split all available CPUs
        # into a fake NUMA topology
        if os.environ.get("VLLM_CPU_SIM_MULTI_NUMA", None) is not None:
            os.environ["_SIM_MULTI_NUMA"] = str(
                parallel_config.world_size * parallel_config._api_process_count
            )

    @classmethod
    def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
        # TODO: CPU still sets block_size in check_and_update_config.
        # Move that logic here so block_size is chosen by the backend.
        pass

    @classmethod
    def _resolve_smt(cls) -> int:
        """Set ``cls.smt`` for the current architecture and return it.

        PowerPC uses 4 OMP threads per core; other architectures keep the
        class default.
        """
        if cls.get_cpu_architecture() == CpuArchEnum.POWERPC:
            cls.smt = 4
        return cls.smt

    @classmethod
    def get_omp_manager(cls) -> OMPProcessManager:
        """Return the shared OMP resource manager, creating it on first use.

        When ``simulate_numa`` is positive (CI), the manager's places are
        split to emulate at least that many NUMA nodes.
        """
        if cls.omp_process_manager is None:
            cls.omp_process_manager = OMPProcessManager(
                affinity=cls.get_global_cpu_mask(), smt=cls._resolve_smt()
            )
            # we need to fix up the topology returned by the OMP Manager for
            # simulated NUMA environments in CI
            if cls.simulate_numa > 0:
                cls._simulate_numa_topology(cls.omp_process_manager)
        return cls.omp_process_manager

    @classmethod
    def _simulate_numa_topology(cls, om: OMPProcessManager) -> None:
        """Halve ``om.omp_places`` until at least ``cls.simulate_numa`` exist.

        Each round splits every place with more than one CPU into two halves
        of its sorted CPU list (both marked available). Places with a single
        CPU cannot be split and are kept unchanged rather than dropped.

        Raises:
            ValueError: if no place can be split before the target is reached.
        """
        logger.info(
            "Adjusting numa topology to resemble at least %d nodes",
            int(cls.simulate_numa),
        )
        while len(om.omp_places) < cls.simulate_numa:
            new_omp_places = []
            touched = False
            for omp_place in om.omp_places:
                if len(omp_place["mask"]) <= 1:
                    # Keep unsplittable places so their CPUs are not lost.
                    new_omp_places.append(omp_place)
                    continue
                touched = True
                cpu_list = sorted(omp_place["mask"])
                half = len(cpu_list) // 2
                new_omp_places.append({"mask": set(cpu_list[:half]), "available": True})
                new_omp_places.append({"mask": set(cpu_list[half:]), "available": True})
            if not touched:
                raise ValueError(
                    "Cannot split the existing NUMA topology to match "
                    "simulation requirements"
                )
            om.omp_places = new_omp_places

    @classmethod
    def get_global_cpu_mask(cls) -> set[int]:
        """Return (and cache) the set of CPU ids this process may use.

        Uses the process affinity where ``os.sched_getaffinity`` exists;
        otherwise (e.g. macOS) assumes all ``os.cpu_count()`` CPUs.
        """
        if cls.global_cpu_mask is None:
            if hasattr(os, "sched_getaffinity"):
                cls.global_cpu_mask = os.sched_getaffinity(0)
            else:
                cls.global_cpu_mask = set(range(os.cpu_count() or 1))
        return cls.global_cpu_mask

    @classmethod
    def reserve_cpus(cls, reserve: set[int]) -> bool:
        """Remove ``reserve`` from the global CPU mask and rebuild the manager.

        There is no "release" mechanism. Nothing is changed and False is
        returned if an existing manager already has a place marked
        unavailable.
        """
        if cls.omp_process_manager is not None and any(
            not place["available"] for place in cls.omp_process_manager.omp_places
        ):
            return False
        cls.global_cpu_mask = cls.get_global_cpu_mask() - reserve
        # reinitialize OMP resource management (with the architecture's SMT
        # level, even if get_omp_manager() has not run yet)
        # NOTE(review): the simulated-NUMA split applied in get_omp_manager()
        # is not re-applied here; confirm that is intended.
        cls.omp_process_manager = OMPProcessManager(
            affinity=cls.global_cpu_mask, smt=cls._resolve_smt()
        )
        return True

    @staticmethod
    def _parse_cpu_list(text: str) -> list[int]:
        """Parse a sysfs CPU list such as ``"0-3,8,10-11"`` into CPU ids.

        Raises:
            ValueError: if an entry is neither an integer nor an ``a-b`` range.
        """
        cpus: list[int] = []
        for part in text.strip().split(","):
            if "-" in part:
                first, last = map(int, part.split("-"))
                cpus.extend(range(first, last + 1))
            else:
                cpus.append(int(part))
        return cpus

    @classmethod
    def discover_numa_topology(cls) -> list[list[int]]:
        """
        Discover the NUMA topology from sysfs.

        Returns one list per NUMA node, ordered by node id. Each list holds
        that node's physical core ids in ascending order, where a physical
        core is identified by the lowest CPU id among its thread siblings.
        Nodes without CPUs are skipped; [] is returned if sysfs lacks the
        node/cpu directories. The groups are meant for nixl start_kv_load(),
        which keeps the last physical core of each NUMA node.
        """
        SYS_NODE = "/sys/devices/system/node"
        SYS_CPU = "/sys/devices/system/cpu"

        if not (os.path.exists(SYS_NODE) and os.path.exists(SYS_CPU)):
            return []

        # Sort numerically so groups follow node order, not listdir order.
        node_names = sorted(
            (
                name
                for name in os.listdir(SYS_NODE)
                if name.startswith("node") and name[4:].isdigit()
            ),
            key=lambda name: int(name[4:]),
        )

        core_rsv_for_kv = []
        for node in node_names:
            node_path = f"{SYS_NODE}/{node}"

            physical_cores = set()
            for cpu in os.listdir(node_path):
                if not cpu.startswith("cpu") or not cpu[3:].isdigit():
                    continue

                cpu_id = int(cpu[3:])
                # thread_siblings based on cpu_id
                path = f"{SYS_CPU}/cpu{cpu_id}/topology/thread_siblings_list"
                siblings = [cpu_id]
                if os.path.exists(path):
                    try:
                        with open(path) as f:
                            siblings = cls._parse_cpu_list(f.read()) or [cpu_id]
                    except (OSError, ValueError):
                        siblings = [cpu_id]

                physical_cores.add(min(siblings))

            if physical_cores:
                # Sorted so "last" is deterministically the highest core id.
                core_rsv_for_kv.append(sorted(physical_cores))

        return core_rsv_for_kv

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa

    @classmethod
    def supports_structured_output(cls) -> bool:
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

    @classmethod
    def import_kernels(cls) -> None:
        """Import the compiled CPU kernel extension matching this CPU.

        On x86 the variant is picked by AVX512/AVX512-BF16/AVX2 support;
        other architectures import ``vllm._C``. Import failures are logged
        as warnings rather than raised.
        """
        if Platform.get_cpu_architecture() in (CpuArchEnum.X86,):
            # Note: The lib name is _C_AVX2/AVX512, but the module name is _C.
            # This will cause an exception "dynamic module does not define
            # module export function". But the library is imported
            # successfully. So ignore the exception for now, until we find
            # a solution.
            ignored_msg = "dynamic module does not define module export function"
            if torch.cpu._is_avx512_supported():
                if torch.cpu._is_avx512_bf16_supported():
                    try:
                        import vllm._C  # noqa: F401
                    except ImportError as e:
                        logger.warning("Failed to import from vllm._C: %r", e)
                else:
                    try:
                        import vllm._C_AVX512  # noqa: F401
                    except ImportError as e:
                        if ignored_msg not in e.msg:
                            logger.warning(
                                "Failed to import from vllm._C_AVX512: %r", e
                            )
            else:
                try:
                    import vllm._C_AVX2  # noqa: F401
                except ImportError as e:
                    if ignored_msg not in e.msg:
                        logger.warning("Failed to import from vllm._C_AVX2: %r", e)
        else:
            try:
                import vllm._C  # noqa: F401
            except ImportError as e:
                logger.warning("Failed to import from vllm._C: %r", e)