"tests/entrypoints/openai/test_chat.py" did not exist on "01bfb22b4112ee813185366ab26985d172661a61"
cpu.py 14.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import json
5
import os
6
import platform
7
import subprocess
8
import sys
9
from dataclasses import dataclass
10
from importlib.util import find_spec
11
from typing import TYPE_CHECKING
12

13
import regex as re
14
15
import torch

16
from vllm.logger import init_logger
17
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
18

19
from .interface import CpuArchEnum, Platform, PlatformEnum
20
21

logger = init_logger(__name__)

# `_Backend` and `VllmConfig` are imported only for static type checking.
# At runtime they are bound to None so that annotations referencing them
# (evaluated at function-definition time) still resolve.
if TYPE_CHECKING:
    from vllm.attention.backends.registry import _Backend
    from vllm.config import VllmConfig
else:
    _Backend = None
    VllmConfig = None

30

31
def get_max_threads(pid=0):
    """Return the number of CPUs the process may run on.

    On platforms providing ``os.sched_getaffinity`` this is the size of the
    affinity mask of ``pid`` (0 means the calling process). On macOS, where
    affinity masks are unavailable, ``os.cpu_count()`` is used and ``pid`` is
    ignored.

    Raises:
        NotImplementedError: on any other operating system.
    """
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(pid))
    if platform.system() == "Darwin":
        # os.cpu_count() may return None when the count is undetermined;
        # callers stringify this value into thread-count env vars, so fall
        # back to a single thread rather than producing "None".
        return os.cpu_count() or 1
    raise NotImplementedError("Unsupported OS")


40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
@dataclass
class LogicalCPUInfo:
    """One logical CPU entry decoded from ``lscpu`` JSON output.

    Every field defaults to -1, which marks a missing or unparsable value.
    """

    id: int = -1  # logical CPU id (JSON key "cpu")
    physical_core: int = -1  # physical core id (JSON key "core")
    numa_node: int = -1  # NUMA node id (JSON key "node")

    @classmethod
    def _int(cls, value: str) -> int:
        """Convert ``value`` to ``int``; return -1 if it cannot be converted."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # Non-numeric strings (e.g. "-"), None, or non-finite floats.
            return -1

    @staticmethod
    def json_decoder(obj_dict: dict):
        """``object_hook`` for ``json.loads``.

        Turns any JSON object that has all of the "cpu", "core" and "node"
        keys into a ``LogicalCPUInfo``; every other object is returned
        unchanged so that outer containers decode normally.
        """
        cpu_id = obj_dict.get("cpu")
        physical_core = obj_dict.get("core")
        numa_node = obj_dict.get("node")

        if cpu_id is None or physical_core is None or numa_node is None:
            return obj_dict

        return LogicalCPUInfo(
            id=LogicalCPUInfo._int(cpu_id),
            physical_core=LogicalCPUInfo._int(physical_core),
            numa_node=LogicalCPUInfo._int(numa_node),
        )


70
71
class CpuPlatform(Platform):
    """vLLM platform implementation for running on host CPUs.

    Chooses the supported model dtypes per CPU architecture, always selects
    the Torch SDPA attention backend, and adapts the engine configuration and
    process environment for the CPU worker.
    """

    _enum = PlatformEnum.CPU
    device_name: str = "cpu"
    device_type: str = "cpu"
    dispatch_key: str = "CPU"
    dist_backend: str = "gloo"
    # Comma-separated NUMA node ids; parsed in get_allowed_cpu_core_node_list.
    device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        """Return the model dtypes usable on the current CPU architecture."""
        arch = self.get_cpu_architecture()

        if arch == CpuArchEnum.POWERPC:
            return [torch.bfloat16, torch.float32]

        if arch == CpuArchEnum.ARM and sys.platform.startswith("darwin"):
            # Apple Silicon: bf16 is only offered when sysctl reports the
            # FEAT_BF16 CPU feature.
            try:
                feat_bf16 = subprocess.check_output(
                    ["sysctl", "-n", "hw.optional.arm.FEAT_BF16"]
                ).strip()
            except (OSError, subprocess.CalledProcessError):
                # sysctl missing, or the key is unknown on this system:
                # treat bf16 as unavailable instead of failing.
                feat_bf16 = b"0"
            if feat_bf16 == b"1":
                return [torch.bfloat16, torch.float16, torch.float32]
            return [torch.float16, torch.float32]

        if arch == CpuArchEnum.RISCV:
            # Workaround for Issue #25655: RISC-V scheduler bug with float16.
            #
            # Background:
            # - RISC-V currently uses the scalar code path.
            # - A latent bug in the vLLM scheduler provides invalid
            #   physical_block_idx values under certain conditions.
            # - This causes segmentation faults with the float16 dtype on
            #   RISC-V; forcing float32 has been observed to bypass it.
            #
            # Technical details:
            # - The bug manifests as an out-of-bounds physical_block_idx in
            #   block_tables.
            # - Only observed on RISC-V hardware (tested on Sophgo SG2044);
            #   does not reproduce on x86 or other architectures.
            # - Root cause is in Python-level scheduling logic, not in the
            #   C++ kernels.
            #
            # Temporary workaround until the scheduler bug is fixed.
            # See: https://github.com/vllm-project/vllm/issues/25655
            return [torch.float32]

        # x86/aarch64 CPU has supported both bf16 and fp16 natively.
        return [torch.bfloat16, torch.float16, torch.float32]

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name; always "cpu" regardless of device_id."""
        return "cpu"

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "_Backend",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: str | None,
        block_size: int,
        use_v1: bool,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
    ) -> str:
        """Return the import path of the attention backend class for CPU.

        Only Torch SDPA on the V1 engine is available. A different requested
        backend is logged and ignored.

        Raises:
            NotImplementedError: if MLA or sparse attention is requested.
            ValueError: if the V1 engine is not in use.
        """
        from vllm.attention.backends.registry import _Backend

        if selected_backend and selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        if use_mla:
            raise NotImplementedError("MLA is not supported on CPU.")
        if use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on CPU.")
        if not use_v1:
            raise ValueError("CPU backend only supports V1.")
        # Log only once every precondition has passed, so a failing request
        # does not report a backend that is never used.
        logger.info("Using Torch SDPA backend.")
        return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the KV-cache budget in bytes.

        The value comes from ``envs.VLLM_CPU_KVCACHE_SPACE`` (in GiB). When
        that setting is None, 4 GiB is used and a one-time warning is logged.
        ``device_id`` is unused on CPU.
        """
        import vllm.envs as envs
        from vllm.utils.mem_constants import GiB_bytes

        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
        if kv_cache_space is None:
            kv_cache_space = 4 * GiB_bytes  # type: ignore
            logger.warning_once(
                "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) "
                "for CPU backend is not set, using 4 by default."
            )
        else:
            kv_cache_space *= GiB_bytes

        return kv_cache_space

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cpu.set_device(device)

    @classmethod
    def inference_mode(cls):
        """Return the grad-disabling context used for inference on CPU."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate and adjust ``vllm_config`` in place for the CPU backend.

        Also exports the environment variables the CPU executor and its
        worker processes rely on.

        Raises:
            RuntimeError: for a non-16 block size without
                intel_extension_for_pytorch, or for an FP8 KV cache combined
                with chunked prefill / prefix caching.
        """
        model_config = vllm_config.model_config

        if model_config is not None:
            model_config.disable_cascade_attn = True

        cache_config = vllm_config.cache_config

        # IPEX enables larger default KV-cache blocks.
        ipex_available = find_spec("intel_extension_for_pytorch") is not None

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 128 if ipex_available else 16

        if not ipex_available and cache_config.block_size != 16:
            raise RuntimeError(
                f"--block-size={cache_config.block_size} requires"
                " intel_extension_for_pytorch"
            )

        scheduler_config = vllm_config.scheduler_config
        if (
            scheduler_config.chunked_prefill_enabled
            or cache_config.enable_prefix_caching
        ) and cache_config.cache_dtype != "auto":
            raise RuntimeError(
                "Chunked-prefill and prefix-cache on the CPU "
                "backend is not compatible with FP8 KV cache."
            )

        if cache_config.cache_dtype == "fp8_e4m3":
            cache_config.cache_dtype = "fp8_e5m2"
            logger.warning(
                "CPU backend doesn't support fp8_e4m3 KV cache type, cast to fp8_e5m2."
            )

        if (
            cache_config.cache_dtype != "auto"
            and model_config is not None
            and model_config.dtype == torch.half
        ):
            logger.warning(
                "FP8 KV cache on the CPU backend only does not"
                " support fp16 for now, cast to bf16."
            )
            model_config.dtype = torch.bfloat16

        cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory()

        parallel_config = vllm_config.parallel_config
        if (
            parallel_config.world_size > 1
            and parallel_config.distributed_executor_backend is not None
            and parallel_config.distributed_executor_backend != "mp"
        ):
            logger.warning(
                (
                    "%s is not supported on CPU, fallback to mp "
                    "distributed executor backend."
                ),
                parallel_config.distributed_executor_backend,
            )
            parallel_config.distributed_executor_backend = "mp"
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker"

        # Disable DBO
        if parallel_config.enable_dbo:
            logger.warning("Dual-Batch Overlap is not supported on CPU, disabled.")
            parallel_config.enable_dbo = False

        # Note: workaround for v1 gpu_model_runner
        from vllm.config import CompilationMode

        vllm_config.compilation_config.cudagraph_capture_sizes = []

        compilation_config = vllm_config.compilation_config
        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
            # Note: vLLM V1 is using PIECEWISE level compilation, which will
            # take time to compile kernels just-in-time with the inductor
            # backend. For CPU CI tests, most of them are executed fast and
            # compilations consume too much time, even with torch compile
            # cache. So use VLLM_CPU_CI_ENV to indicate the CI environment,
            # and just execute model with dynamo + eager mode to save time.
            # VLLM_CPU_CI_ENV is only used as an internal variable.
            if os.environ.get("VLLM_CPU_CI_ENV", "0") != "0":
                backend = "eager"
            else:
                backend = "inductor"

            compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
            compilation_config.backend = backend
            compilation_config.inductor_compile_config.update(
                {
                    "dce": True,
                    "size_asserts": False,
                    "nan_asserts": False,
                    "epilogue_fusion": True,
                }
            )

        if vllm_config.lora_config is not None:
            compilation_config.mode = CompilationMode.NONE

        assert vllm_config.device_config.device_type == "cpu"

        #
        # Environment variables for CPU executor
        #

        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        # Note: to avoid the error 'nthreads cannot be larger than environment
        #  variable "NUMEXPR_MAX_THREADS" (64)'.
        os.environ["NUMEXPR_MAX_THREADS"] = str(get_max_threads())

        # Set default threads num for OpenMP parallel
        os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Disable multi-stream for shared experts as there is no Stream on
        # CPU. The flag must be set to a true value ("1") to disable it.
        # NOTE(review): this was previously "0", which contradicts the
        # "disable" intent if vllm.envs parses the flag as an int boolean.
        os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"

        # Intel OpenMP setting
        ld_preload_str = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_preload_str:
            # The time(milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ["KMP_BLOCKTIME"] = "1"
            # Prevents the CPU to run into low performance state
            os.environ["KMP_TPAUSE"] = "0"
            # Provides fine granularity parallelism
            os.environ["KMP_FORKJOIN_BARRIER_PATTERN"] = "dist,dist"
            os.environ["KMP_PLAIN_BARRIER_PATTERN"] = "dist,dist"
            os.environ["KMP_REDUCTION_BARRIER_PATTERN"] = "dist,dist"

        # To hint IPEX uses shared memory based AllReduce
        os.environ["LOCAL_WORLD_SIZE"] = str(
            vllm_config.parallel_config.tensor_parallel_size
        )

        if model_config is not None and model_config.use_mla:
            # NOTE(review): the message mentions prefix caching, but only
            # chunked prefill is turned off below — confirm intent.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.chunked_prefill_enabled = False
            # Without chunked prefill a whole prompt must fit in one batch.
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )

    @classmethod
    def get_allowed_cpu_core_node_list(cls) -> tuple[list[int], list[LogicalCPUInfo]]:
        """Return the usable NUMA nodes and the usable logical CPUs.

        Logical CPUs come from ``lscpu``. Entries with an unparsable id, core,
        or node are dropped, as are CPUs outside this process's affinity mask.
        The NUMA node list is every node owning at least one remaining CPU,
        sorted ascending. If ``CPU_VISIBLE_MEMORY_NODES`` is set and non-empty,
        the list is instead the nodes it names (in the given order) that are
        also in that set.

        Only supported on Linux.
        """
        assert platform.system() == "Linux"

        # Init LogicalCPUInfo from lscpu
        lscpu_output = subprocess.check_output(
            ["lscpu", "-J", "-e=CPU,CORE,NODE"], text=True
        )
        # Rewrite an unquoted "-" node value (not valid JSON) to node 0.
        lscpu_output = re.sub(r'"node":\s*-\s*(,|\n)', r'"node": 0\1', lscpu_output)
        logical_cpu_list: list[LogicalCPUInfo] = json.loads(
            lscpu_output, object_hook=LogicalCPUInfo.json_decoder
        )["cpus"]

        # Filter CPUs with invalid attributes
        logical_cpu_list = [
            x
            for x in logical_cpu_list
            if -1 not in (x.id, x.physical_core, x.numa_node)
        ]

        # Filter allowed CPUs
        if hasattr(os, "sched_getaffinity"):
            allowed_cpu_id_list = os.sched_getaffinity(0)
        else:
            raise NotImplementedError("Unsupported OS")
        logical_cpu_list = [x for x in logical_cpu_list if x.id in allowed_cpu_id_list]

        # Get allowed NUMA nodes
        allowed_numa_nodes = {x.numa_node for x in logical_cpu_list}
        allowed_numa_nodes_list = sorted(allowed_numa_nodes)

        env_key = CpuPlatform.device_control_env_var
        if os.environ.get(env_key, "") != "":
            visible_nodes = [int(s) for s in os.environ[env_key].split(",")]
            # Keep only requested NUMA nodes that own an allowed CPU. (This
            # previously compared node ids against CPU ids by mistake.)
            allowed_numa_nodes_list = [
                x for x in visible_nodes if x in allowed_numa_nodes
            ]

        return allowed_numa_nodes_list, logical_cpu_list

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Pinned host memory is not supported on CPU; always False."""
        logger.warning("Pin memory is not supported on CPU.")
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the CPU LoRA (Punica) wrapper."""
        return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa

    @classmethod
    def supports_structured_output(cls) -> bool:
        """Structured output is supported on CPU."""
        return True

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Report that the attention op is opaque on this platform."""
        return True

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Hybrid KV cache is supported on CPU."""
        return True