xpu.py 9.24 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import os
6
from typing import TYPE_CHECKING
7

8
9
import torch

10
from vllm.logger import init_logger
11
from vllm.v1.attention.backends.registry import AttentionBackendEnum
12

13
from .interface import DeviceCapability, Platform, PlatformEnum
14

15
if TYPE_CHECKING:
16
    from vllm.config import VllmConfig
17
    from vllm.v1.attention.selector import AttentionSelectorConfig
18
19
20
else:
    VllmConfig = None

21
logger = init_logger(__name__)
22
23
24
25


class XPUPlatform(Platform):
    _enum = PlatformEnum.XPU
26
    device_name: str = "xpu"
27
    device_type: str = "xpu"
28
    dispatch_key: str = "XPU"
29
30
31
    # Intel XPU's device key is "GPU" for Ray.
    # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
    ray_device_key: str = "GPU"
32
    dist_backend: str = "ccl"  # ccl | xccl
33
    device_control_env_var: str = "ZE_AFFINITY_MASK"
34

35
    @classmethod
36
37
38
39
    def import_kernels(cls) -> None:
        # Do not import vllm._C
        with contextlib.suppress(ImportError):
            import vllm._moe_C  # noqa: F401
40

41
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Return the path of the attention backend to use on XPU.

        As a side effect, the KV cache layout is set to ``"NHD"`` on every
        call, before any backend is chosen.

        Args:
            selected_backend: The explicitly requested backend, or a falsy
                value to let this method pick the default.
            attn_selector_config: Selection inputs; ``dtype``, ``use_sparse``
                and ``use_mla`` are read from it.

        Returns:
            The value of ``get_path()`` for the chosen backend:
            Triton when it is requested or when ``dtype`` is ``float32``,
            otherwise Flash Attention.

        Raises:
            NotImplementedError: If ``use_sparse`` is set.
            ValueError: If ``selected_backend`` is truthy, is neither
                ``TRITON_ATTN`` nor ``FLASH_ATTN``, and ``dtype`` is not
                ``float32``.
        """
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout("NHD")
        logger.info(
            "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
            "only NHD layout is supported by XPU attention kernels."
        )

        dtype = attn_selector_config.dtype
        if attn_selector_config.use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on XPU.")

        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
            logger.info_once("Using Triton backend.")
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        # The float32 check intentionally precedes the remaining backend
        # checks: any non-Triton selection falls back to Triton for float32.
        if dtype == torch.float32:
            logger.warning_once(
                "Flash Attention on XPU does not support float32 dtype. "
                "Falling back to Triton Attention backend."
            )
            return AttentionBackendEnum.TRITON_ATTN.get_path()

        if selected_backend and selected_backend != AttentionBackendEnum.FLASH_ATTN:
            # Name the rejected backend so the failure is actionable.
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}: "
                f"{selected_backend}, "
                f"with use_mla: {attn_selector_config.use_mla}"
            )

        # Covers both an explicit FLASH_ATTN request and the default choice;
        # logged once, matching the other branches.
        logger.info_once("Using Flash Attention backend.")
        return AttentionBackendEnum.FLASH_ATTN.get_path()
78

79
80
81
82
83
84
85
86
87
88
89
90
    @classmethod
    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
        """Return the backends usable for vision (ViT) attention on XPU.

        Only ``FLASH_ATTN`` is listed.
        """
        supported = [AttentionBackendEnum.FLASH_ATTN]
        return supported

    @classmethod
    def get_vit_attn_backend(
        cls,
        head_size: int,
        dtype: torch.dtype,
        backend: "AttentionBackendEnum | None" = None,
    ) -> "AttentionBackendEnum":
        """Pick the backend used for vision (ViT) attention.

        Args:
            head_size: Not used by this implementation.
            dtype: Not used by this implementation.
            backend: An explicitly requested backend, or ``None`` to use
                the default (``FLASH_ATTN``).

        Returns:
            ``backend`` when given and supported, otherwise ``FLASH_ATTN``.

        Raises:
            AssertionError: If ``backend`` is given but is not in
                ``get_supported_vit_attn_backends()``.
        """
        if backend is None:
            logger.info_once(
                f"Using backend {AttentionBackendEnum.FLASH_ATTN} for vit attention"
            )
            return AttentionBackendEnum.FLASH_ATTN

        supported = cls.get_supported_vit_attn_backends()
        assert backend in supported, (
            f"Backend {backend} is not supported for vit attention. "
            f"Supported backends are: "
            f"{supported}."
        )
        logger.info_once(f"Using backend {backend} for vit attention")
        return backend

107
108
109
110
111
112
113
    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """Make ``device`` the current XPU device via ``torch.xpu.set_device``."""
        torch.xpu.set_device(device)

114
    @classmethod
115
    def get_device_capability(
116
117
        cls,
        device_id: int = 0,
118
    ) -> DeviceCapability | None:
119
120
121
        # capacity format differs from cuda's and will cause unexpected
        # failure, so use None directly
        return None
122

123
124
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name ``torch.xpu`` reports for device ``device_id``."""
        name = torch.xpu.get_device_name(device_id)
        return name
126

127
128
    @classmethod
    def get_punica_wrapper(cls) -> str:
129
130
131
132
133
        xpu_use_triton_kernel = os.getenv("XPU_USE_TRITON_KERNEL", "0") == "1"
        if not xpu_use_triton_kernel:
            return "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"
        else:
            return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
134

135
136
137
138
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the properties of device ``device_id``."""
        return torch.xpu.get_device_properties(device_id).total_memory
139

140
141
    @classmethod
    def inference_mode(cls):
142
        return torch.no_grad()
143
144
145

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place for running on XPU.

        The following are applied, in order:

        * default the cache block size to 64 when it is unset;
        * default ``compile_sizes`` to an empty list when it is unset;
        * assert that the CUDA graph mode is ``NONE``;
        * disable compilation when a LoRA config is present;
        * set ``IGC_ForceOCLSIMDWidth=16`` in the process environment when
          speculative decoding is configured;
        * select the XPU worker class when ``worker_cls`` is still ``"auto"``;
        * enable ``enable_permute_local_kv`` when a KV transfer config exists;
        * for MLA models, disable chunked prefill and raise
          ``max_num_batched_tokens`` to at least ``max_model_len``.

        Raises:
            AssertionError: If the CUDA graph mode is not ``NONE``.
        """
        cache_cfg = vllm_config.cache_config
        model_cfg = vllm_config.model_config

        # in V1(or with ipex chunked prefill) block_size is 64
        if cache_cfg and cache_cfg.block_size is None:
            cache_cfg.block_size = 64

        # lazy import to avoid circular import
        from vllm.config import CompilationMode, CUDAGraphMode

        comp_cfg = vllm_config.compilation_config
        if comp_cfg.compile_sizes is None:
            comp_cfg.compile_sizes = []

        assert comp_cfg.cudagraph_mode == CUDAGraphMode.NONE, (
            "CUDA graph mode should be NONE on XPU"
        )

        if vllm_config.lora_config is not None:
            comp_cfg.mode = CompilationMode.NONE

        # decrease triton kernel compilation scratch space for speculative decoding
        if vllm_config.speculative_config is not None:
            os.environ["IGC_ForceOCLSIMDWidth"] = "16"  # noqa: SIM112

        # check and update parallel config
        par_cfg = vllm_config.parallel_config
        # Only override worker_cls if it's still the default "auto"
        # This allows custom workers (like vllm-omni workers) to be used on XPU
        if par_cfg.worker_cls == "auto":
            par_cfg.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"

        kv_cfg = vllm_config.kv_transfer_config
        if kv_cfg is not None:
            kv_cfg.enable_permute_local_kv = True

        if not (model_cfg and model_cfg.use_mla):
            return

        # NOTE(review): the message mentions prefix caching, but only chunked
        # prefill and max_num_batched_tokens are changed below -- confirm.
        logger.info(
            "MLA is enabled on a non-GPU platform; forcing chunked "
            "prefill and prefix caching to be disabled."
        )
        sched_cfg = vllm_config.scheduler_config
        sched_cfg.enable_chunked_prefill = False
        sched_cfg.max_num_batched_tokens = max(
            model_cfg.max_model_len,
            sched_cfg.DEFAULT_MAX_NUM_BATCHED_TOKENS,
        )
187

188
189
190
191
    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True

192
193
194
195
    @classmethod
    def support_static_graph_mode(cls) -> bool:
        return False

196
197
    @classmethod
    def is_pin_memory_available(cls):
198
        return True
199
200

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
    ) -> float:
        """Return ``torch.xpu.max_memory_allocated`` after a peak-stat reset.

        The peak memory statistics for ``device`` are reset first, so the
        value is read immediately after that reset.
        """
        torch.xpu.reset_peak_memory_stats(device)
        allocated = torch.xpu.max_memory_allocated(device)
        return allocated
206

207
208
209
210
    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        return torch.float8_e5m2

211
212
213
214
215
    @classmethod
    def is_data_center_gpu(cls) -> bool:
        """Return whether the default device's name marks a data center GPU.

        The check is a case-insensitive search for ``"data center gpu"`` in
        the name returned by ``get_device_name()``.
        """
        return "data center gpu" in cls.get_device_name().lower()

216
217
218
    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
219
220
221

    @classmethod
    def device_count(cls) -> int:
        """Return the number of devices reported by ``torch.xpu``."""
        num_devices = torch.xpu.device_count()
        return num_devices
223
224

    @classmethod
225
226
    def check_if_supports_dtype(cls, dtype: torch.dtype):
        if dtype == torch.bfloat16:  # noqa: SIM102
227
228
229
230
231
232
            device_name = cls.get_device_name().lower()
            # client gpu a770
            if device_name.count("a770") > 0:
                raise ValueError(
                    "Intel Arc A770 have bfloat16 accuracy known issue. "
                    "You can use float16 instead by explicitly setting the "
233
234
                    "`dtype` flag in CLI, for example: --dtype=half."
                )
235
236
237
238

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262

    @classmethod
    def insert_blocks_to_device(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from src_cache to dst_cache on XPU."""
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

    @classmethod
    def swap_out_blocks_to_host(
        cls,
        src_cache: torch.Tensor,
        dst_cache: torch.Tensor,
        src_block_indices: torch.Tensor,
        dst_block_indices: torch.Tensor,
    ) -> None:
        """Copy blocks from XPU to host (CPU)."""
        _src_cache = src_cache[:, src_block_indices]
        dst_cache[:, dst_block_indices] = _src_cache.cpu()