# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import os
from typing import TYPE_CHECKING, Optional

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

from .interface import DeviceCapability, Platform, PlatformEnum

if TYPE_CHECKING:
    from vllm.attention.backends.registry import _Backend
    from vllm.config import ModelConfig, VllmConfig
else:
    # Runtime placeholders: these names are only needed for type checking,
    # so bind them to None to let annotations that mention them resolve
    # without importing the real modules at module-load time.
    ModelConfig = None
    VllmConfig = None
    _Backend = None

logger = init_logger(__name__)


class XPUPlatform(Platform):
    """Platform hooks and device metadata for Intel XPU devices."""

    _enum = PlatformEnum.XPU
    device_name: str = "xpu"
    device_type: str = "xpu"
    dispatch_key: str = "XPU"
    # Intel XPU's device key is "GPU" for Ray.
    # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
    ray_device_key: str = "GPU"
    dist_backend: str = "ccl"  # ccl | xccl
    device_control_env_var: str = "ZE_AFFINITY_MASK"

    @classmethod
    def import_kernels(cls) -> None:
        """Import the optional ``vllm._moe_C`` extension if it is available.

        ``vllm._C`` is deliberately not imported on XPU; a missing
        ``vllm._moe_C`` is tolerated silently.
        """
        # Do not import vllm._C
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "_Backend",
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: Optional[str],
        block_size: int,
        use_v1: bool,
        use_mla: bool,
        has_sink: bool,
        use_sparse,
    ) -> str:
        """Return the dotted import path of the attention backend class.

        Args:
            selected_backend: Backend explicitly requested; a falsy value
                selects the default (Flash Attention).
            head_size, dtype, kv_cache_dtype, block_size, has_sink: Not
                consulted on XPU.
            use_v1: Ignored; the V1 flag is re-read from
                ``envs.VLLM_USE_V1``.
            use_mla: Only reported in the error message.
            use_sparse: Sparse attention is rejected on XPU.

        Raises:
            NotImplementedError: If ``use_sparse`` is truthy.
            ValueError: If V1 is disabled, or an unsupported backend is
                explicitly selected.
        """
        from vllm.attention.backends.registry import _Backend

        if use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on XPU.")
        # The caller-supplied flag is overridden by the environment setting.
        use_v1 = envs.VLLM_USE_V1
        if not use_v1:
            raise ValueError("XPU backend only supports V1.")
        TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
        FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
        if selected_backend == _Backend.TRITON_ATTN:
            logger.info_once("Using Triton backend on V1 engine.")
            return TRITON_ATTN
        elif selected_backend == _Backend.FLASH_ATTN:
            logger.info_once("Using Flash Attention backend on V1 engine.")
            return FLASH_ATTN
        elif selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}"
            )

        # No explicit selection: default to Flash Attention.
        logger.info("Using Flash Attention backend on V1 engine.")
        return FLASH_ATTN

    @classmethod
    def is_kv_cache_dtype_supported(
        cls, kv_cache_dtype: str, model_config: "ModelConfig"
    ) -> bool:
        """
        Check if the kv_cache_dtype is supported.

        XPU only supports an fp8 KV cache with the Triton backend: this
        returns True only when ``VLLM_ATTENTION_BACKEND`` is explicitly set
        to ``"TRITON_ATTN"`` and ``kv_cache_dtype`` is one of the fp8
        variants; otherwise it returns False. ``model_config`` is unused.
        """
        if (
            envs.is_set("VLLM_ATTENTION_BACKEND")
            and envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN"
        ):
            return kv_cache_dtype in ("fp8_e4m3", "fp8_e5m2", "fp8")

        return False

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.xpu.set_device(device)

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """Always return None; XPU capabilities are not reported."""
        # The XPU capability format differs from CUDA's and would cause
        # unexpected failures, so return None directly.
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name torch reports for XPU device ``device_id``."""
        return torch.xpu.get_device_name(device_id)

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the XPU LoRA Punica wrapper class."""
        return "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the XPU device properties."""
        device_props = torch.xpu.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def inference_mode(cls):
        """Return the context manager used for inference on XPU.

        ``torch.no_grad()`` is returned rather than
        ``torch.inference_mode()``.
        """
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Validate ``vllm_config`` for XPU and patch XPU-specific defaults.

        Mutates ``vllm_config`` in place:
        * defaults the cache ``block_size`` to 64 when unset;
        * defaults ``compile_sizes`` to ``[]`` and disables compilation when
          LoRA is configured;
        * selects the XPU worker and a supported distributed executor
          backend (may also set ``VLLM_WORKER_MULTIPROC_METHOD=spawn`` in
          the process environment);
        * disables chunked prefill when the model uses MLA;
        * forces the KV-cache layout to NHD.

        Raises:
            AssertionError: If the CUDA graph mode is not NONE.
        """
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        # in V1 (or with ipex chunked prefill) block_size is 64
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 64

        # lazy import to avoid circular import
        from vllm.config import CompilationLevel, CUDAGraphMode

        compilation_config = vllm_config.compilation_config
        if compilation_config.compile_sizes is None:
            compilation_config.compile_sizes = []

        assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, (
            "CUDA graph mode should be NONE on XPU"
        )

        if vllm_config.lora_config is not None:
            compilation_config.level = CompilationLevel.NO_COMPILATION

        # check and update parallel config
        parallel_config = vllm_config.parallel_config
        parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"

        if parallel_config.distributed_executor_backend is None:
            if parallel_config.world_size > 1:
                parallel_config.distributed_executor_backend = "ray"
            else:
                parallel_config.distributed_executor_backend = "uni"
        elif parallel_config.distributed_executor_backend == "mp":
            # FIXME(kunshang):
            # spawn needs calling `if __name__ == '__main__':`
            # fork is not supported for xpu start new process.
            if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
                os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
                logger.warning(
                    "Please use spawn as start method if you want to use mp."
                )
        elif parallel_config.distributed_executor_backend not in (
            "ray",
            "uni",
            "external_launcher",
        ):
            logger.warning(
                "%s is not supported on XPU, fallback to ray distributed"
                " executor backend.",
                parallel_config.distributed_executor_backend,
            )
            parallel_config.distributed_executor_backend = "ray"

        if model_config and model_config.use_mla:
            # NOTE(review): the message mentions prefix caching, but only the
            # chunked-prefill settings are changed below — confirm intent.
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled."
            )
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.chunked_prefill_enabled = False
            # Without chunked prefill, a whole prompt must fit in one batch.
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout("NHD")
        logger.info(
            "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
            "only NHD layout is supported by XPU attention kernels."
        )

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Report that hybrid KV-cache is supported on XPU."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Report that static graph mode is not supported on XPU."""
        return False

    @classmethod
    def is_pin_memory_available(cls):
        """Report that pinned host memory is available on XPU."""
        return True

    @classmethod
    def get_current_memory_usage(
        cls, device: Optional[torch.types.Device] = None
    ) -> float:
        """Return XPU memory usage for ``device``.

        Side effect: resets the XPU peak-memory statistics before reading
        ``max_memory_allocated``. Because the peak is reset immediately
        before the read, the value presumably reflects the current
        allocation rather than a historical peak (confirm against
        PyTorch's XPU memory-stat semantics).
        """
        torch.xpu.reset_peak_memory_stats(device)
        return torch.xpu.max_memory_allocated(device)

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the fp8 dtype used on XPU (``torch.float8_e5m2``)."""
        return torch.float8_e5m2

    @classmethod
    def is_data_center_gpu(cls) -> bool:
        """Return True if the device 0 name contains "data center gpu"."""
        return "data center gpu" in cls.get_device_name().lower()

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the XPU device communicator class."""
        return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa

    @classmethod
    def device_count(cls) -> int:
        """Return the number of XPU devices visible to torch."""
        return torch.xpu.device_count()

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """Reject dtypes with known problems on the current XPU device.

        Raises:
            ValueError: If ``torch_dtype`` is bfloat16 and the device name
                contains "a770" (Intel Arc A770 client GPU).
        """
        # Short-circuit keeps the device-name query limited to bfloat16.
        if torch_dtype == torch.bfloat16 and "a770" in cls.get_device_name().lower():
            raise ValueError(
                "Intel Arc A770 have bfloat16 accuracy known issue. "
                "You can use float16 instead by explicitly setting the "
                "`dtype` flag in CLI, for example: --dtype=half."
            )

    @classmethod
    def opaque_attention_op(cls) -> bool:
        """Report that the attention op is opaque on XPU."""
        return True

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on XPU.

        Blocks are indexed along dimension 1 of both caches; the gathered
        source blocks are moved to ``dst_cache``'s device before the write.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from XPU to host (CPU).

        Blocks are indexed along dimension 1 of both caches; the gathered
        source blocks are moved to CPU before being written into
        ``dst_cache``.
        """
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()