cpu.py 18.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import glob
5
import json
6
import os
7
import platform
8
import subprocess
9
import sys
10
from dataclasses import dataclass
11
from typing import TYPE_CHECKING
12

13
import psutil
14
import regex as re
15
16
import torch

17
from vllm import envs
18
from vllm.logger import init_logger
19
from vllm.v1.attention.backend import is_quantized_kv_cache
20
from vllm.v1.attention.backends.registry import AttentionBackendEnum
21

22
from .interface import CpuArchEnum, Platform, PlatformEnum
23
24

if TYPE_CHECKING:
    # Only needed by static type checkers.
    from vllm.config import VllmConfig
    from vllm.v1.attention.selector import AttentionSelectorConfig
else:
    # Runtime placeholder so unquoted ``VllmConfig`` annotations in this
    # module can still be evaluated without importing vllm.config.
    VllmConfig = None

logger = init_logger(__name__)


def get_max_threads(pid: int = 0) -> int:
    """Return the number of logical CPUs usable by process ``pid``.

    Where ``os.sched_getaffinity`` exists, this is the size of the process's
    CPU affinity mask. On macOS (no affinity API) ``pid`` is ignored and the
    machine's logical CPU count is returned.

    Raises:
        NotImplementedError: on any other platform.
    """
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(pid))
    if platform.system() == "Darwin":
        # os.cpu_count() returns None when the count cannot be determined;
        # fall back to 1 so callers always receive a positive int (the
        # result is exported as NUMEXPR_MAX_THREADS, where "None" is invalid).
        return os.cpu_count() or 1
    raise NotImplementedError("Unsupported OS")


@dataclass
class LogicalCPUInfo:
    """One logical CPU record as decoded from ``lscpu -J`` output.

    Fields that could not be parsed as integers are stored as ``-1``.
    """

    id: int = -1  # logical CPU id, from the "cpu" key
    physical_core: int = -1  # physical core id, from the "core" key
    numa_node: int = -1  # NUMA node id, from the "node" key

    @classmethod
    def _int(cls, value: str) -> int:
        """Convert ``value`` to ``int``; return ``-1`` if it is not numeric."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # None / containers -> TypeError, non-numeric text -> ValueError,
            # infinite floats -> OverflowError.
            return -1

    @staticmethod
    def json_decoder(obj_dict: dict) -> "LogicalCPUInfo | dict":
        """``object_hook`` for ``json.loads`` on ``lscpu -J`` output.

        Objects carrying all of the "cpu", "core" and "node" keys become a
        ``LogicalCPUInfo``; any other object (e.g. the top-level wrapper
        dict) is returned unchanged.
        """
        cpu_id = obj_dict.get("cpu")
        physical_core = obj_dict.get("core")
        numa_node = obj_dict.get("node")

        if cpu_id is None or physical_core is None or numa_node is None:
            return obj_dict

        return LogicalCPUInfo(
            id=LogicalCPUInfo._int(cpu_id),
            physical_core=LogicalCPUInfo._int(physical_core),
            numa_node=LogicalCPUInfo._int(numa_node),
        )


class CpuPlatform(Platform):
    _enum = PlatformEnum.CPU
74
    device_name: str = "cpu"
75
    device_type: str = "cpu"
76
    dispatch_key: str = "CPU"
77
    dist_backend: str = "gloo"
78
    device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"
79

80
    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Floating-point dtypes supported on the host CPU.

        - POWER: bfloat16 and float32.
        - ARM on macOS: float16 and float32, plus bfloat16 when the
          ``hw.optional.arm.FEAT_BF16`` sysctl reports ``1``.
        - Everything else (x86, aarch64, RISC-V): bfloat16, float16, float32.
        """
        arch = self.get_cpu_architecture()
        if arch == CpuArchEnum.POWERPC:
            return [torch.bfloat16, torch.float32]

        if arch == CpuArchEnum.ARM and sys.platform.startswith("darwin"):
            # Invoke sysctl directly (no shell). If the key is unknown to this
            # macOS release or sysctl cannot be run, the probe fails; treat
            # that as "no native BF16" instead of crashing.
            try:
                feat_bf16 = subprocess.check_output(
                    ["sysctl", "-n", "hw.optional.arm.FEAT_BF16"],
                    stderr=subprocess.DEVNULL,
                ).strip()
            except (subprocess.CalledProcessError, OSError):
                feat_bf16 = b""
            if feat_bf16 == b"1":
                return [torch.bfloat16, torch.float16, torch.float32]
            return [torch.float16, torch.float32]

        # x86/aarch64 CPU has supported both bf16 and fp16 natively; RISC-V
        # is given the same list.
        return [torch.bfloat16, torch.float16, torch.float32]

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name; every ``device_id`` maps to ``"cpu"``."""
        del device_id  # not used on the CPU platform
        return "cpu"

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
        num_heads: int | None = None,
    ) -> str:
        """Return the import path of the attention backend used on CPU.

        Only ``CPU_ATTN`` is available. Any other requested backend is
        logged and ignored.

        Raises:
            NotImplementedError: if MLA or sparse attention is requested.
        """
        cpu_backend = AttentionBackendEnum.CPU_ATTN
        requested_other = bool(selected_backend) and selected_backend != cpu_backend
        if requested_other:
            logger.info("Cannot use %s backend on CPU.", selected_backend)

        if attn_selector_config.use_mla:
            raise NotImplementedError("MLA is not supported on CPU.")
        if attn_selector_config.use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on CPU.")

        backend_path = cpu_backend.get_path()
        return backend_path

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the KV-cache budget for the CPU backend, in bytes.

        If ``VLLM_CPU_KVCACHE_SPACE`` is set, it is interpreted as GiB.
        Otherwise the host's total memory is split evenly across the NUMA
        nodes found under ``/sys/devices/system/node`` (one node if that
        directory is absent). Half of one node's share is used, and a
        one-time warning is logged.
        """
        from vllm.utils.mem_constants import GiB_bytes
        from vllm.utils.mem_utils import format_gib

        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
        if kv_cache_space is not None:
            # int() keeps the documented return type even if the setting is
            # not an integer number of GiB.
            return int(kv_cache_space * GiB_bytes)

        node_dir = "/sys/devices/system/node"
        num_numa_nodes = 0
        if os.path.isdir(node_dir):
            # Count only "node<N>" entries.
            num_numa_nodes = sum(
                1
                for name in os.listdir(node_dir)
                if name.startswith("node") and name[4:].isdigit()
            )
        per_node_memory = psutil.virtual_memory().total // max(num_numa_nodes, 1)

        DEFAULT_CPU_MEM_UTILIZATION = 0.5
        kv_cache_space = int(per_node_memory * DEFAULT_CPU_MEM_UTILIZATION)
        logger.warning_once(
            "VLLM_CPU_KVCACHE_SPACE not set. Using %s GiB for KV cache.",
            format_gib(kv_cache_space),
        )
        return kv_cache_space

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Select ``device`` as current through ``torch.cpu.set_device``."""
        cpu_backend_module = torch.cpu
        cpu_backend_module.set_device(device)

    @classmethod
    def inference_mode(cls):
        """Return the context manager wrapping inference on CPU.

        This platform uses ``torch.no_grad()``.
        """
        grad_disabled_ctx = torch.no_grad()
        return grad_disabled_ctx

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place so it can run on the CPU backend.

        Updates the model, cache, scheduler, parallel, compilation and
        profiler sub-configs. It also sets process-wide environment variables
        (multiprocessing start method, OpenMP / Intel OpenMP tuning,
        ``LD_PRELOAD``, ``LOCAL_WORLD_SIZE``), presumably consumed by the
        worker processes.

        Raises:
            RuntimeError: if chunked prefill or prefix caching is enabled
                together with a quantized KV-cache dtype.
        """
        model_config = vllm_config.model_config
        if model_config is not None:
            model_config.disable_cascade_attn = True

        cache_config = vllm_config.cache_config
        if not cache_config.user_specified_block_size:
            cache_config.block_size = 128
        if cache_config.block_size % 32 != 0:
            logger.warning(
                "CPU backend prefers block_size is multiples of 32, "
                "otherwise the performance is not optimized."
            )

        scheduler_config = vllm_config.scheduler_config
        # async scheduling is not required on CPU
        scheduler_config.async_scheduling = False
        # This check runs before the fp8 -> "auto" fallback below. A quantized
        # KV cache combined with chunked prefill or prefix caching is
        # therefore rejected rather than silently downgraded.
        if (
            scheduler_config.enable_chunked_prefill
            or cache_config.enable_prefix_caching
        ) and is_quantized_kv_cache(cache_config.cache_dtype):
            raise RuntimeError(
                "Chunked-prefill and prefix-cache on the CPU "
                "backend is not compatible with FP8 KV cache."
            )
        if cache_config.cache_dtype.startswith("fp8"):
            logger.warning(
                "CPU backend doesn't support KV cache quantization fallback to auto."
            )
            cache_config.cache_dtype = "auto"

        cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory()

        # reserve at least one core for nixl_connector under p/d case
        if vllm_config.kv_transfer_config and (
            envs.VLLM_CPU_NUM_OF_RESERVED_CPU in (0, None)
        ):
            os.environ["VLLM_CPU_NUM_OF_RESERVED_CPU"] = "1"

        parallel_config = vllm_config.parallel_config
        if (
            parallel_config.world_size > 1
            and parallel_config.distributed_executor_backend is not None
            and parallel_config.distributed_executor_backend != "mp"
        ):
            logger.warning(
                "%s is not supported on CPU, fallback to mp "
                "distributed executor backend.",
                parallel_config.distributed_executor_backend,
            )
            parallel_config.distributed_executor_backend = "mp"
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker"
        # Disable DBO
        if parallel_config.enable_dbo:
            logger.warning("Dual-Batch Overlap is not supported on CPU, disabled.")
            parallel_config.enable_dbo = False

        cls._configure_cpu_compilation(vllm_config)

        vllm_config.profiler_config.torch_profiler_dump_cuda_time_total = False

        assert vllm_config.device_config.device_type == "cpu"

        cls._set_cpu_runtime_env_vars(vllm_config)

        if model_config is not None and model_config.use_mla:
            # NOTE(review): the message mentions prefix caching, but only
            # chunked prefill is turned off here.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            scheduler_config.enable_chunked_prefill = False
            scheduler_config.max_num_batched_tokens = max(
                model_config.max_model_len,
                scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

    @classmethod
    def _configure_cpu_compilation(cls, vllm_config: "VllmConfig") -> None:
        """Disable CUDA-graph capture and choose the torch.compile setup."""
        from vllm.config import CompilationMode

        compilation_config = vllm_config.compilation_config
        # Note: workaround for v1 gpu_model_runner
        compilation_config.cudagraph_capture_sizes = []

        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
            # Note: vLLM V1 is using PIECEWISE level compilation, which will
            # take time to compile kernels just-in-time with the inductor
            # backend. For CPU CI tests, most of them are executed fast and
            # compilations consume too much time, even with torch compile
            # cache. So use VLLM_CPU_CI_ENV to indicate the CI environment,
            # and just execute model with dynamo + eager mode to save time.
            # VLLM_CPU_CI_ENV is only used as an internal variable.
            if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
                backend = "eager"
            else:
                backend = "inductor"

            compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
            compilation_config.backend = backend
            compilation_config.inductor_compile_config.update(
                {
                    "dce": True,
                    "size_asserts": False,
                    "nan_asserts": False,
                    "epilogue_fusion": True,
                }
            )

        # With LoRA enabled, compilation is switched off entirely.
        if vllm_config.lora_config is not None:
            compilation_config.mode = CompilationMode.NONE

    @classmethod
    def _set_cpu_runtime_env_vars(cls, vllm_config: "VllmConfig") -> None:
        """Export the environment variables used by the CPU executor."""
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        # Note: to avoid the error 'nthreads cannot be larger than environment
        # variable "NUMEXPR_MAX_THREADS" (64)'.
        os.environ["NUMEXPR_MAX_THREADS"] = str(get_max_threads())

        if envs.VLLM_CPU_OMP_THREADS_BIND != "nobind":
            # Set default threads num for OpenMP parallel
            os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())
        else:
            # In this case, setting the OpenMP configuration via
            # OMP_NUM_THREADS is up to the user.
            logger.info("Disabling binding processes to CPU cores...")

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Disable multi-stream for shared experts as no Stream on CPU
        os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"

        ld_preload_str = os.getenv("LD_PRELOAD", "")

        # Intel OpenMP setting
        if "libiomp5.so" in ld_preload_str:
            # The time(milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ["KMP_BLOCKTIME"] = "1"
            # Prevents the CPU to run into low performance state
            os.environ["KMP_TPAUSE"] = "0"
            # Provides fine granularity parallelism
            os.environ["KMP_FORKJOIN_BARRIER_PATTERN"] = "dist,dist"
            os.environ["KMP_PLAIN_BARRIER_PATTERN"] = "dist,dist"
            os.environ["KMP_REDUCTION_BARRIER_PATTERN"] = "dist,dist"

        if (
            platform.system() == "Linux"
            and Platform.get_cpu_architecture()
            in (CpuArchEnum.ARM, CpuArchEnum.POWERPC)
            and not ("libomp" in ld_preload_str or "libgomp" in ld_preload_str)
        ):
            # We need to LD_PRELOAD PyTorch's libgomp, otherwise only
            # one core will be properly utilized when we thread-bind
            # See: https://github.com/vllm-project/vllm/issues/27369
            # TODO: Remove once:
            # https://github.com/pytorch/pytorch/issues/166087 is fixed
            pytorch_libgomp_so = cls._find_pytorch_libgomp()
            if pytorch_libgomp_so is not None:
                os.environ["LD_PRELOAD"] = (
                    f"{ld_preload_str}:{pytorch_libgomp_so}"
                    if ld_preload_str
                    else pytorch_libgomp_so
                )

        os.environ["LOCAL_WORLD_SIZE"] = str(
            vllm_config.parallel_config.tensor_parallel_size
        )

    @staticmethod
    def _find_pytorch_libgomp() -> str | None:
        """Return the path of a libgomp shipped with PyTorch, or None.

        ``torch.libs`` (next to the torch package) is searched before
        ``torch/lib``. See https://github.com/vllm-project/vllm/issues/30470.
        Matches are sorted because ``glob`` returns them in arbitrary order.
        """
        torch_pkg = os.path.dirname(torch.__file__)
        site_root = os.path.dirname(torch_pkg)
        for torch_libs in (
            os.path.join(site_root, "torch.libs"),
            os.path.join(torch_pkg, "lib"),
        ):
            candidates = sorted(glob.glob(os.path.join(torch_libs, "libgomp*.so*")))
            if candidates:
                return candidates[0]
        return None

    @classmethod
    def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
        """No-op on CPU.

        TODO: CPU still sets block_size in check_and_update_config.
        Move that logic here so block_size is chosen by the backend.
        """

    @classmethod
    def get_allowed_cpu_core_node_list(cls) -> tuple[list[int], list[LogicalCPUInfo]]:
        """Return the NUMA nodes and logical CPUs this process may use.

        Logical CPUs come from ``lscpu``. Entries with missing or invalid
        ids are dropped, and the rest are restricted to the process's CPU
        affinity mask. The node list is the sorted set of nodes that own at
        least one allowed CPU. When ``CPU_VISIBLE_MEMORY_NODES`` is set, the
        list is further limited to the comma-separated node ids it contains.

        Returns:
            ``(allowed_numa_nodes, allowed_logical_cpus)``.

        Raises:
            NotImplementedError: if ``os.sched_getaffinity`` is unavailable.
            ValueError: if ``CPU_VISIBLE_MEMORY_NODES`` has a non-integer item.
        """
        # Only Linux is supported (asserted).
        assert platform.system() == "Linux"

        # Init LogicalCPUInfo from lscpu; no shell is needed for a fixed argv.
        lscpu_output = subprocess.check_output(
            ["lscpu", "-J", "-e=CPU,CORE,NODE"], text=True
        )
        # A bare `-` node value is not valid JSON; rewrite it as node 0.
        lscpu_output = re.sub(r'"node":\s*-\s*(,|\n)', r'"node": 0\1', lscpu_output)
        decoded_cpus = json.loads(
            lscpu_output, object_hook=LogicalCPUInfo.json_decoder
        )["cpus"]

        # Keep fully decoded records with valid attributes. json_decoder
        # leaves records that lack a key as plain dicts.
        logical_cpu_list: list[LogicalCPUInfo] = [
            x
            for x in decoded_cpus
            if isinstance(x, LogicalCPUInfo)
            and -1 not in (x.id, x.physical_core, x.numa_node)
        ]

        # Filter allowed CPUs
        if not hasattr(os, "sched_getaffinity"):
            raise NotImplementedError("Unsupported OS")
        allowed_cpu_ids = os.sched_getaffinity(0)
        logical_cpu_list = [x for x in logical_cpu_list if x.id in allowed_cpu_ids]

        # Get allowed NUMA nodes
        allowed_numa_nodes = {x.numa_node for x in logical_cpu_list}
        allowed_numa_nodes_list = sorted(allowed_numa_nodes)

        visible_env = os.environ.get(CpuPlatform.device_control_env_var, "")
        if visible_env.strip():
            # Skip empty items such as the one left by a trailing comma.
            visible_nodes = {int(s) for s in visible_env.split(",") if s.strip()}
            allowed_numa_nodes_list = sorted(visible_nodes & allowed_numa_nodes)

        return allowed_numa_nodes_list, logical_cpu_list

    @classmethod
    def discover_numa_topology(cls) -> list[list[int]]:
        """Group physical cores by NUMA node for nixl ``start_kv_load()``.

        For each ``/sys/devices/system/node/node<N>`` directory, every
        logical CPU listed there is mapped to a physical core. The physical
        core is identified by the lowest CPU id in its
        ``topology/thread_siblings_list``. One ascending list of physical-core
        ids is returned per node, with nodes in ascending ``<N>`` order.
        Nodes without CPUs are skipped, so the outer index is not necessarily
        the node id. Presumably callers reserve the last (highest) core of
        each group; TODO confirm.

        Returns an empty list when the sysfs topology is unavailable.
        """
        SYS_NODE = "/sys/devices/system/node"
        SYS_CPU = "/sys/devices/system/cpu"

        if not (os.path.exists(SYS_NODE) and os.path.exists(SYS_CPU)):
            return []

        # os.listdir() order is arbitrary; sort nodes numerically so the
        # result is deterministic.
        node_names = sorted(
            (
                name
                for name in os.listdir(SYS_NODE)
                if name.startswith("node") and name[4:].isdigit()
            ),
            key=lambda name: int(name[4:]),
        )

        core_rsv_for_kv: list[list[int]] = []
        for node in node_names:
            node_path = f"{SYS_NODE}/{node}"
            physical_cores: set[int] = set()
            for entry in os.listdir(node_path):
                if not entry.startswith("cpu") or not entry[3:].isdigit():
                    continue
                cpu_id = int(entry[3:])
                # thread_siblings based on cpu_id
                siblings_path = f"{SYS_CPU}/cpu{cpu_id}/topology/thread_siblings_list"
                siblings = cls._read_thread_siblings(siblings_path, cpu_id)
                physical_cores.add(min(siblings))
            if physical_cores:
                # Sorted, so "last element" means the highest physical core id.
                core_rsv_for_kv.append(sorted(physical_cores))

        return core_rsv_for_kv

    @staticmethod
    def _read_thread_siblings(path: str, cpu_id: int) -> list[int]:
        """Parse a sysfs CPU list file such as ``0-3,8`` into CPU ids.

        Returns ``[cpu_id]`` if the file is missing, unreadable, malformed
        or empty.
        """
        if not os.path.exists(path):
            return [cpu_id]
        try:
            with open(path) as f:
                text = f.read()
            cpus: list[int] = []
            for part in text.strip().split(","):
                if "-" in part:
                    first, last = map(int, part.split("-"))
                    cpus.extend(range(first, last + 1))
                else:
                    cpus.append(int(part))
        except (OSError, ValueError):
            return [cpu_id]
        return cpus or [cpu_id]

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Report that pinned host memory is not used on the CPU platform."""
        pin_memory_supported = False
        return pin_memory_supported

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the CPU Punica (LoRA) wrapper class."""
        wrapper_path = "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
        return wrapper_path

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the communicator for distributed CPU runs."""
        communicator_path = (
            "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"
        )
        return communicator_path

    @classmethod
    def supports_structured_output(cls) -> bool:
        """Report structured-output support on CPU (always ``True``)."""
        supported = True
        return supported

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Report that the attention op is treated as opaque on CPU (``True``)."""
        is_opaque = True
        return is_opaque

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Report hybrid KV-cache support on CPU (always ``True``)."""
        hybrid_supported = True
        return hybrid_supported

    @classmethod
    def import_kernels(cls) -> None:
        """Import the compiled vLLM CPU kernel extension for this host.

        On x86 the module is picked by ISA support:
        - AVX-512 with BF16: ``vllm._C``
        - AVX-512 without BF16: ``vllm._C_AVX512``
        - otherwise: ``vllm._C_AVX2``

        Other architectures import ``vllm._C``. ImportErrors are logged, not
        raised.
        """
        if Platform.get_cpu_architecture() not in (CpuArchEnum.X86,):
            cls._import_cpu_kernel_module("vllm._C")
            return

        # Note: The lib name is _C_AVX2/AVX512, but the module name is _C.
        # This causes an exception "dynamic module does not define module
        # export function". But the library is imported successfully, so
        # ignore that exception for now, until we find a solution.
        ignored_msg = "dynamic module does not define module export function"
        if not torch.cpu._is_avx512_supported():
            cls._import_cpu_kernel_module("vllm._C_AVX2", ignored_msg)
        elif torch.cpu._is_avx512_bf16_supported():
            cls._import_cpu_kernel_module("vllm._C")
        else:
            cls._import_cpu_kernel_module("vllm._C_AVX512", ignored_msg)

    @staticmethod
    def _import_cpu_kernel_module(
        module_name: str, ignored_msg: str | None = None
    ) -> None:
        """Import ``module_name`` for its side effects, logging ImportErrors.

        An ImportError whose text contains ``ignored_msg`` is not logged.
        """
        import importlib

        try:
            importlib.import_module(module_name)
        except ImportError as e:
            # str(e) rather than e.msg: msg is None for an ImportError built
            # without exactly one argument, which would make `in` raise.
            if ignored_msg is None or ignored_msg not in str(e):
                logger.warning("Failed to import from %s: %r", module_name, e)