cpu.py 7.21 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
import sys
5
from importlib.util import find_spec
6
from typing import TYPE_CHECKING, Optional
7

8
import psutil
9
10
import torch

11
12
from vllm.logger import init_logger

13
14
15
from .interface import Platform, PlatformEnum, _Backend

logger = init_logger(__name__)
16

17
18
19
20
21
22
23
if TYPE_CHECKING:
    # Only needed for annotations; importing it solely under TYPE_CHECKING
    # keeps the import out of the runtime import graph.
    from vllm.config import VllmConfig
else:
    # Runtime placeholder so annotations that name VllmConfig still evaluate.
    VllmConfig = None

# NOTE: the module-level ``logger`` is already created right after the
# imports above; the second, identical ``init_logger(__name__)`` assignment
# that used to follow this guard was redundant and has been dropped.

24
25
26

class CpuPlatform(Platform):
    _enum = PlatformEnum.CPU
27
    device_name: str = "cpu"
28
    device_type: str = "cpu"
29
    dispatch_key: str = "CPU"
30

31
32
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name, which is always ``"cpu"``.

        ``device_id`` is accepted for interface compatibility and is not
        consulted.
        """
        return "cpu"

35
    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Return the dotted path of the attention backend class to use.

        A request for any backend other than ``_Backend.TORCH_SDPA`` is only
        logged; it does not change the result. The choice depends solely on
        ``use_mla``: the CPU MLA backend when it is truthy, the Torch SDPA
        backend otherwise. ``head_size``, ``dtype``, ``kv_cache_dtype``,
        ``block_size`` and ``use_v1`` are not consulted here.
        """
        unsupported_request = (selected_backend
                               and selected_backend != _Backend.TORCH_SDPA)
        if unsupported_request:
            logger.info("Cannot use %s backend on CPU.", selected_backend)

        if not use_mla:
            logger.info("Using Torch SDPA backend.")
            return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"

        logger.info("Using CPU MLA backend.")
        return "vllm.attention.backends.cpu_mla.CPUMLABackend"
47

48
49
50
51
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the ``total`` field reported by ``psutil.virtual_memory()``.

        ``device_id`` is accepted for interface compatibility and is not
        consulted.
        """
        memory_stats = psutil.virtual_memory()
        return memory_stats.total

52
53
54
55
    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Report that async output is not supported on this platform.

        Always returns ``False``; ``enforce_eager`` is not consulted.
        """
        return False

56
57
    @classmethod
    def inference_mode(cls):
        """Return a fresh ``torch.no_grad()`` object for running inference."""
        no_grad_ctx = torch.no_grad()
        return no_grad_ctx
59
60
61
62
63
64

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` for the CPU platform and adjust it in place.

        In order, this method:

        1. Forces ``model_config.enforce_eager`` on when it is falsy.
        2. Fills in a default KV-cache block size (128 when
           ``intel_extension_for_pytorch`` is importable, otherwise 16) and
           rejects any other block size when that package is missing.
        3. Rejects chunked prefill / prefix caching combined with a
           non-``"auto"`` KV-cache dtype.
        4. Rewrites an ``fp8_e4m3`` KV-cache dtype to ``fp8_e5m2`` and, when
           the KV-cache dtype is not ``"auto"``, switches a ``torch.half``
           model dtype to ``torch.bfloat16``.
        5. Derives ``cache_config.cpu_kvcache_space_bytes`` from
           ``envs.VLLM_CPU_KVCACHE_SPACE`` (0 means "use 4 GiB").
        6. Forces the distributed executor backend to ``"mp"`` and resolves
           an ``"auto"`` worker class.
        7. Exports the environment variables used by the CPU executor.

        Raises:
            RuntimeError: for an unsupported block size, for chunked prefill
                or prefix caching with a non-"auto" KV-cache dtype, or for a
                negative ``VLLM_CPU_KVCACHE_SPACE``.
        """
        import vllm.envs as envs
        from vllm.utils import GiB_bytes

        model_config = vllm_config.model_config
        # Reminder: Please update docs/source/features/compatibility_matrix.md
        # if the feature combo becomes valid.
        if not model_config.enforce_eager:
            model_config.enforce_eager = True

        cache_config = vllm_config.cache_config

        has_ipex = find_spec("intel_extension_for_pytorch") is not None

        # Choose a default block size only when none was configured.
        if cache_config and cache_config.block_size is None:
            if has_ipex:
                cache_config.block_size = 128
            else:
                cache_config.block_size = 16

        if not has_ipex and cache_config.block_size != 16:
            raise RuntimeError(
                f"--block-size={cache_config.block_size} requires"
                " intel_extension_for_pytorch")

        scheduler_config = vllm_config.scheduler_config
        prefill_feature_on = (scheduler_config.chunked_prefill_enabled
                              or cache_config.enable_prefix_caching)
        if prefill_feature_on and cache_config.cache_dtype != "auto":
            raise RuntimeError("Chunked-prefill and prefix-cache on the CPU "
                               "backend is not compatible with FP8 KV cache.")

        if cache_config.cache_dtype == "fp8_e4m3":
            cache_config.cache_dtype = "fp8_e5m2"
            logger.warning(
                "CPU backend doesn't support fp8_e4m3 KV cache type, "
                "cast to fp8_e5m2.")

        fp16_with_custom_kv_dtype = (cache_config.cache_dtype != "auto"
                                     and model_config.dtype == torch.half)
        if fp16_with_custom_kv_dtype:
            logger.warning("FP8 KV cache on the CPU backend only does not"
                           " support fp16 for now, cast to bf16.")
            model_config.dtype = torch.bfloat16

        # KV-cache budget, read from VLLM_CPU_KVCACHE_SPACE (in GiB).
        kv_space_gib = envs.VLLM_CPU_KVCACHE_SPACE

        if kv_space_gib == 0:
            # Zero is treated as "not set": fall back to 4 GiB.
            cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
            logger.warning(
                "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) "
                "for CPU backend is not set, using 4 by default.")
        elif kv_space_gib > 0:
            cache_config.cpu_kvcache_space_bytes = (  # type: ignore
                kv_space_gib * GiB_bytes)
        else:
            raise RuntimeError(
                "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
                f" {kv_space_gib}, expect a positive integer value.")

        parallel_config = vllm_config.parallel_config
        requested_backend = parallel_config.distributed_executor_backend
        if requested_backend is not None and requested_backend != "mp":
            logger.warning(("%s is not supported on CPU, fallback to mp "
                            "distributed executor backend."),
                           requested_backend)
            parallel_config.distributed_executor_backend = "mp"

        if parallel_config.worker_cls == "auto":
            cpu_worker_path = "vllm.worker.cpu_worker.CPUWorker"
            if vllm_config.speculative_config:
                # With speculative decoding the spec-decode factory is the
                # worker class and the CPU worker is set as sd_worker_cls.
                parallel_config.worker_cls = \
                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                parallel_config.sd_worker_cls = cpu_worker_path
            else:
                parallel_config.worker_cls = cpu_worker_path

        assert vllm_config.device_config.device_type == "cpu"

        #
        # Environment variables for CPU executor
        #

        # Set default threads num for OpenMP parallel
        omp_thread_count = torch.get_num_threads()
        os.environ["OMP_NUM_THREADS"] = str(omp_thread_count)

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Intel OpenMP setting
        ld_preload = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_preload:
            os.environ.update({
                # The time(milliseconds) that a thread should wait after
                # completing the execution of a parallel region, before
                # sleeping.
                'KMP_BLOCKTIME': "1",
                # Prevents the CPU to run into low performance state
                'KMP_TPAUSE': "0",
                # Provides fine granularity parallelism
                'KMP_FORKJOIN_BARRIER_PATTERN': "dist,dist",
                'KMP_PLAIN_BARRIER_PATTERN': "dist,dist",
                'KMP_REDUCTION_BARRIER_PATTERN': "dist,dist",
            })

        # To hint IPEX uses shared memory based AllReduce
        tp_size = vllm_config.parallel_config.tensor_parallel_size
        os.environ["LOCAL_WORLD_SIZE"] = str(tp_size)

        # On macOS, switch to "spawn" only when the method resolved by envs
        # is "fork" and the variable is absent from the process environment.
        if (sys.platform == "darwin"
                and envs.VLLM_WORKER_MULTIPROC_METHOD == "fork"
                and 'VLLM_WORKER_MULTIPROC_METHOD' not in os.environ):
            logger.warning(
                "Default to spawn method on MacOS. If this is not desired,"
                " set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.")
            os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
167

168
169
170
171
    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Report that pinned memory is unavailable on CPU.

        Emits a warning on every call and always returns ``False``.
        """
        logger.warning("Pin memory is not supported on CPU.")
        pin_memory_supported = False
        return pin_memory_supported
172
173
174
175

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the Punica wrapper class for CPU."""
        wrapper_path = "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
        return wrapper_path
176
177
178
179
180
181
182

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Return the dotted path of the communicator class used for
        distributed communication on CPU.
        """
        return ("vllm.distributed.device_communicators"
                ".cpu_communicator.CpuCommunicator")
183
184
185
186

    @classmethod
    def supports_structured_output(cls) -> bool:
        """Report that structured output is supported; always ``True``."""
        return True